[feat][assistant][I40FFO] add new Ascend operator ApplyKerasMomentum

This commit is contained in:
zorax 2021-07-21 21:14:12 +08:00
parent b690ae1b7a
commit 6098c1da5e
8 changed files with 303 additions and 1 deletions

View File

@ -412,6 +412,7 @@ inline const PrimitivePtr kPrimHShrink = std::make_shared<Primitive>("HShrink");
inline const PrimitivePtr kPrimHShrinkGrad = std::make_shared<Primitive>("HShrinkGrad");
inline const PrimitivePtr kPrimApplyAdagradDA = std::make_shared<Primitive>("ApplyAdagradDA");
inline const PrimitivePtr kPrimSparseApplyRMSProp = std::make_shared<Primitive>("SparseApplyRMSProp");
inline const PrimitivePtr kPrimApplyKerasMomentum = std::make_shared<Primitive>("ApplyKerasMomentum");
// Comm ops
inline const PrimitivePtr kPrimMirror = std::make_shared<Primitive>("_MirrorOperator");

View File

@ -0,0 +1,90 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/apply_keras_momentum.h"
#include <algorithm>
#include <set>
#include "ops/op_utils.h"
#include "abstract/primitive_infer_map.h"
#include "utils/tensor_construct_utils.h"
namespace mindspore {
namespace ops {
namespace {
abstract::TupleShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, 5, prim_name);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
auto var_shape = input_args[0]->BuildShape();
auto accum_shape = input_args[1]->BuildShape();
auto lr_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->GetShapeTrack())[kShape];
auto grad_shape = input_args[3]->BuildShape();
auto momentum_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[4]->GetShapeTrack())[kShape];
// lr, momentum must be scalar
(void)CheckAndConvertUtils::CheckInteger("lr_shape size", lr_shape.size(), kEqual, 0, prim_name);
(void)CheckAndConvertUtils::CheckInteger("momentum_shape size", momentum_shape.size(), kEqual, 0, prim_name);
// var, accum and grad must have the same shape
std::map<std::string, abstract::BaseShapePtr> same_shape_args_map;
same_shape_args_map.insert({"accum", accum_shape});
same_shape_args_map.insert({"grad", grad_shape});
for (auto &elem : same_shape_args_map) {
if (*elem.second != *var_shape) {
MS_EXCEPTION(ValueError) << prim_name << " evaluator arg " << elem.first << " shape " << elem.second->ToString()
<< " are not consistent with var shape " << var_shape->ToString();
}
}
return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{var_shape, accum_shape});
}
TuplePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(prim);
auto prim_name = prim->name();
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, 5, prim_name);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
auto var_type = input_args[0]->BuildType();
auto accum_type = input_args[1]->BuildType();
auto lr_type = input_args[2]->BuildType();
auto grad_type = input_args[3]->BuildType();
auto momentum_type = input_args[4]->BuildType();
const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
// var, accum and grad must have the same type
std::map<std::string, TypePtr> args;
args.insert({"var", var_type});
args.insert({"accum", accum_type});
args.insert({"grad", grad_type});
(void)CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim_name);
// lr, momentum type must be valid
(void)CheckAndConvertUtils::CheckTensorTypeValid("lr_dtype", lr_type, valid_types, prim_name);
(void)CheckAndConvertUtils::CheckTensorTypeValid("momentum_dtype", momentum_type, valid_types, prim_name);
return std::make_shared<Tuple>(std::vector<TypePtr>{var_type, accum_type});
}
} // namespace
AbstractBasePtr ApplyKerasMomentumInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
return abstract::MakeAbstract(InferShape(primitive, input_args), InferType(primitive, input_args));
}
REGISTER_PRIMITIVE_EVAL_IMPL(ApplyKerasMomentum, prim::kPrimApplyKerasMomentum, ApplyKerasMomentumInfer, nullptr, true);
} // namespace ops
} // namespace mindspore

View File

@ -0,0 +1,47 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_APPLY_KERAS_MOMENTUM_H_
#define MINDSPORE_CORE_OPS_APPLY_KERAS_MOMENTUM_H_
#include <map>
#include <memory>
#include <vector>
#include <string>
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace ops {
constexpr auto kNameApplyKerasMomentum = "ApplyKerasMomentum";
class MS_CORE_API ApplyKerasMomentum : public PrimitiveC {
public:
ApplyKerasMomentum() : PrimitiveC(kNameApplyKerasMomentum) {
InitIOName({"var", "accum", "lr", "grad", "momentum"}, {"var", "accum"});
}
~ApplyKerasMomentum() = default;
MS_DECLARE_PARENT(ApplyKerasMomentum, PrimitiveC);
};
AbstractBasePtr ApplyKerasMomentumInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
using PrimApplyKerasMomentumPtr = std::shared_ptr<ApplyKerasMomentum>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_APPLY_KERAS_MOMENTUM_H_

View File

@ -28,6 +28,7 @@ from .add_n import _add_n_tbe
from .add_n_ds import _add_n_ds_tbe
from .accumulate_n_v2 import _accumulate_n_v2_tbe
from .apply_ftrl import _apply_ftrl_tbe
from .apply_keras_momentum import _apply_keras_momentum_tbe
from .apply_momentum import _apply_momentum_tbe
from .apply_adam import _apply_adam_tbe
from .apply_ada_max import _apply_ada_max_tbe

View File

@ -0,0 +1,49 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ApplyKerasMomentum op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
apply_keras_momentum_op_info = TBERegOp("ApplyKerasMomentum") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("apply_keras_momentum_d.so") \
.compute_cost(10) \
.kernel_name("apply_keras_momentum_d") \
.partial_flag(True) \
.attr("use_locking", "optional", "bool", "true,false", "false") \
.attr("use_nesterov", "optional", "bool", "true,false", "false") \
.input(0, "var", False, "required", "all") \
.input(1, "accum", False, "required", "all") \
.input(2, "lr", False, "required", "all") \
.input(3, "grad", False, "required", "all") \
.input(4, "momentum", False, "required", "all") \
.output(0, "var", False, "required", "all") \
.output(1, "accum", False, "required", "all") \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_Default, DataType.F32_5HD,
DataType.F32_Default, DataType.F32_5HD, DataType.F32_5HD) \
.dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_Default, DataType.F32_C1HWNCoC0,
DataType.F32_Default, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_FracZ, DataType.F32_FracZ) \
.get_op_info()
@op_info_register(apply_keras_momentum_op_info)
def _apply_keras_momentum_tbe():
"""ApplyKerasMomentum TBE register"""
return

View File

@ -76,7 +76,7 @@ from .nn_ops import (LSTM, SGD, Adam, AdamWeightDecay, FusedSparseAdam, FusedSpa
MaxPool, DataFormatDimMap,
AvgPool, Conv2DBackpropInput, ComputeAccidentalHits,
MaxPoolWithArgmax, OneHot, Pad, MirrorPad, Mish, PReLU, ReLU, ReLU6, ReLUV2, HSwish, HSigmoid,
ResizeBilinear, Sigmoid, SeLU, HShrink,
ResizeBilinear, Sigmoid, SeLU, HShrink, ApplyKerasMomentum,
SigmoidCrossEntropyWithLogits, NLLLoss, BCEWithLogitsLoss,
SmoothL1Loss, SoftMarginLoss, Softmax, Softsign, Softplus, LRN, RNNTLoss, DynamicRNN, DynamicGRUV2,
SoftmaxCrossEntropyWithLogits, ROIAlign,
@ -185,6 +185,7 @@ __all__ = [
'NLLLoss',
'SGD',
'ApplyMomentum',
"ApplyKerasMomentum",
'FusedWeightScaleApplyMomentum',
'ExpandDims',
'Cast',

View File

@ -9301,3 +9301,98 @@ class SparseApplyRMSProp(Primitive):
self.epsilon = validator.check_number("epsilon", epsilon, 0.0, Rel.GT, self.name)
self.momentum = validator.check_number("momentum", momentum, 0.0, Rel.GE, self.name)
self.rho = validator.check_float_range(rho, 0.0, 1.0, Rel.INC_BOTH, "rho", self.name)
class ApplyKerasMomentum(Primitive):
r"""
Update `var` according to the momentum scheme.
.. math::
\begin{array}{ll} \\
accum = accum * momentum - grad * lr \\
var =
\begin{cases}
var + accum * momentum - grad * lr, &\text{if use_nesterov} \\
var + accum, &\text{else}
\end{cases}
\end{array}
Refer to the paper `On the importance of initialization and momentum in deep
learning <https://dl.acm.org/doi/10.5555/3042817.3043064>`_ for more details.
Inputs of `var`, `accum` and `grad` comply with the implicit type conversion rules
to make the data types consistent.
If they have different data types, lower priority data type will be converted to
relatively highest priority data type.
RuntimeError exception will be thrown when the data type conversion of Parameter is required.
Args:
use_locking (bool): If `True`, updating of the `var` and `accum` tensors will be protected by a lock;
Otherwise the behavior is undefined, but may exhibit less contention. Default: False.
use_nesterov (bool): If `True`, the tensor passed to compute grad will be var + momentum * accum,
so in the end, the var you get is actually var + momentum * accum. Default: False.
Inputs:
- **var** (Parameter) - Variable to be updated. With float16 or float32 data type.
- **accum** (Parameter) - Must have the same shape and type as `var`. With float16 or float32 data type.
- **lr** (Union[Number, Tensor]) - Scaling factor. Must be a scalar. With float16 or float32 data type.
- **grad** (Tensor) - The gradient. Must have the same shape and type as `var`.
With float16 or float32 data type.
- **momentum** (Union[Number, Tensor]) - Momentum. Must be a scalar. With float16 or float32 data type.
Outputs:
Tuple of 2 Tensors, the updated parameters.
- **var** (Tensor) - The same shape and data type as `var`.
- **accum** (Tensor) - The same shape and data type as `accum`.
Raises:
TypeError: If the use_locking or use_nesterov is not a bool.
TypeError: If `var` or `accum` is not a Parameter.
TypeError: If `lr` is neither a Number nor a Tensor.
TypeError: If `grad` is not a Tensor.
TypeError: If `momentum` is neither a Number nor a Tensor.
TypeError: If dtype of `var`, `accum`, `lr`, `grad`, `momentum` is neither float16 nor float32.
ValueError: If `accum` or `grad` doesn't have the same shape as `var`.
ValueError: If the shape size of `lr`, `momentum` is not 0.
Supported Platforms:
``Ascend``
Examples:
>>> class ApplyKerasMomentumNet(nn.Cell):
... def __init__(self, use_locking=False, use_nesterov=False):
... super(ApplyKerasMomentumNet, self).__init__()
... self.apply_keras_momentum = P.ApplyKerasMomentum(use_locking, use_nesterov)
... self.var = Parameter(Tensor(np.array([[0.2, 0.3], [0.1, 0.4]]).astype(np.float32)), name="var")
... self.accum = Parameter(Tensor(np.array([[0.2, 0.3], [0.1, 0.4]]).astype(np.float32)), name="accum")
... def construct(self, lr, grad, momentum):
... out = self.apply_keras_momentum(self.var, self.accum, lr, grad, momentum)
... return out
...
>>> net = ApplyKerasMomentumNet()
>>> lr = Tensor(0.001, mstype.float32)
>>> grad = Tensor(np.array([[0.3, 0.2], [0.4, 0.1]]).astype(np.float32))
>>> momentum = Tensor(0.99, mstype.float32)
>>> output = net(lr, grad, momentum)
>>> print(output)
(Tensor(shape=[2, 2], dtype=Float32, value=
[[ 3.97700012e-01, 5.96800029e-01],
[ 1.98599994e-01, 7.95899987e-01]]), Tensor(shape=[2, 2], dtype=Float32, value=
[[ 1.97699994e-01, 2.96800017e-01],
[ 9.86000001e-02, 3.95900011e-01]]))
"""
__mindspore_signature__ = (
sig.make_sig('var', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
sig.make_sig('accum', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
sig.make_sig('lr', dtype=sig.sig_dtype.T1),
sig.make_sig('grad', dtype=sig.sig_dtype.T),
sig.make_sig('momentum', dtype=sig.sig_dtype.T2)
)
@prim_attr_register
def __init__(self, use_locking=False, use_nesterov=False):
"""Initialize ApplyKerasMomentum"""
validator.check_value_type("use_locking", use_locking, [bool], self.name)
validator.check_value_type("use_nesterov", use_nesterov, [bool], self.name)

View File

@ -1042,6 +1042,18 @@ class SparseApplyRMSPropNet(nn.Cell):
out = self.sparse_apply_r_m_s_prop(self.var, self.ms, self.mom, lr, grad, indices)
return out
class ApplyKerasMomentumNet(nn.Cell):
def __init__(self, use_locking=False, use_nesterov=False):
super(ApplyKerasMomentumNet, self).__init__()
self.apply_keras_momentum = P.ApplyKerasMomentum(use_locking, use_nesterov)
self.var = Parameter(Tensor(np.array([[0.2, 0.3], [0.1, 0.4]]).astype(np.float32)), name="var")
self.accum = Parameter(Tensor(np.array([[0.2, 0.3], [0.1, 0.4]]).astype(np.float32)), name="accum")
def construct(self, lr, grad, momentum):
out = self.apply_keras_momentum(self.var, self.accum, lr, grad, momentum)
return out
test_case_math_ops = [
('BitwiseAnd', {
'block': P.BitwiseAnd(),
@ -2306,6 +2318,12 @@ test_case_nn_ops = [
Tensor(np.array([[0.3, 0.7], [0.1, 0.8]]).astype(np.float32)),
Tensor(np.array([0, 1], dtype=np.int32))],
'skip': ['backward']}),
('ApplyKerasMomentum', {
'block': ApplyKerasMomentumNet(),
'desc_inputs': [Tensor(0.001, mstype.float32),
Tensor(np.array([[0.3, 0.2], [0.4, 0.1]]).astype(np.float32)),
Tensor(0.99, mstype.float32)],
'skip': ['backward']}),
]
test_case_array_ops = [