forked from mindspore-Ecosystem/mindspore
!1924 Add TBE ops ApplyAdaMax, ApplyAdadelta, ApplyAdagrad and ApplyAdagradV2 for VM.
Merge pull request !1924 from liuxiao/ops-for-VM
commit a8378a3357
@@ -70,6 +70,10 @@ static std::map<string, string> tbe_func_adapter_map = {
  {"strided_slice", "strided_slice_d"},
  {"strided_slice_grad", "strided_slice_grad_d"},
  {"sparse_apply_ftrl", "sparse_apply_ftrl_d"},
  {"apply_ada_max", "apply_ada_max_d"},
  {"apply_adadelta", "apply_adadelta_d"},
  {"apply_adagrad", "apply_adagrad_d"},
  {"apply_adagrad_v2", "apply_adagradv2_d"},
  {"transpose", "transpose_d"},
  {"fill", "fill_d"},
  {"unsorted_segment_sum", "unsorted_segment_sum_d"},
@@ -27,6 +27,10 @@ from .add_n import _add_n_tbe
from .apply_ftrl import _apply_ftrl_tbe
from .apply_momentum import _apply_momentum_tbe
from .apply_adam import _apply_adam_tbe
from .apply_ada_max import _apply_ada_max_tbe
from .apply_adadelta import _apply_adadelta_tbe
from .apply_adagrad import _apply_adagrad_tbe
from .apply_adagrad_v2 import _apply_adagrad_v2_tbe
from .adam_apply_one import _adam_apply_one_tbe
from .assign import _assign_tbe
from .assign_add import _assign_add_tbe
@@ -0,0 +1,68 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""ApplyAdaMaxD op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

apply_ada_max_d_op_info = TBERegOp("ApplyAdaMax") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("apply_ada_max_d.so") \
    .compute_cost(10) \
    .kernel_name("apply_ada_max_d") \
    .partial_flag(True) \
    .input(0, "var", False, "required", "all") \
    .input(1, "m", False, "required", "all") \
    .input(2, "v", False, "required", "all") \
    .input(3, "beta1_power", False, "required", "all") \
    .input(4, "lr", False, "required", "all") \
    .input(5, "beta1", False, "required", "all") \
    .input(6, "beta2", False, "required", "all") \
    .input(7, "epsilon", False, "required", "all") \
    .input(8, "grad", False, "required", "all") \
    .output(0, "var", False, "required", "all") \
    .output(1, "m", False, "required", "all") \
    .output(2, "v", False, "required", "all") \
    .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_Default,
                  DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
                  DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
    .dtype_format(DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_Default,
                  DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
                  DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ) \
    .dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_Default,
                  DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
                  DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0) \
    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
                  DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
                  DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_Default,
                  DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
                  DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
    .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_Default,
                  DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
                  DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ) \
    .dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_Default,
                  DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
                  DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
                  DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
                  DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
    .get_op_info()


@op_info_register(apply_ada_max_d_op_info)
def _apply_ada_max_tbe():
    """ApplyAdaMaxD TBE register"""
    return
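A minimal NumPy sketch of the update rule this kernel registration corresponds to, assuming the formulas given in the ApplyAdaMax docstring further down in this diff (the helper name and NumPy usage are illustrative, not part of the commit):

import numpy as np

def ada_max_reference(var, m, v, beta1_power, lr, beta1, beta2, epsilon, grad):
    """Reference AdaMax step: returns the updated (var, m, v)."""
    m = beta1 * m + (1 - beta1) * grad            # 1st moment estimate
    v = np.maximum(beta2 * v, np.abs(grad))       # infinity-norm 2nd moment
    var = var - lr / (1 - beta1_power) * m / (v + epsilon)
    return var, m, v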
@@ -0,0 +1,66 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""ApplyAdadeltaD op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

apply_adadelta_d_op_info = TBERegOp("ApplyAdadelta") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("apply_adadelta_d.so") \
    .compute_cost(10) \
    .kernel_name("apply_adadelta_d") \
    .partial_flag(True) \
    .input(0, "var", False, "required", "all") \
    .input(1, "accum", False, "required", "all") \
    .input(2, "accum_update", False, "required", "all") \
    .input(3, "lr", False, "required", "all") \
    .input(4, "rho", False, "required", "all") \
    .input(5, "epsilon", False, "required", "all") \
    .input(6, "grad", False, "required", "all") \
    .output(0, "var", False, "required", "all") \
    .output(1, "accum", False, "required", "all") \
    .output(2, "accum_update", False, "required", "all") \
    .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_Default,
                  DataType.F16_Default, DataType.F16_Default, DataType.F16_5HD, DataType.F16_5HD,
                  DataType.F16_5HD, DataType.F16_5HD) \
    .dtype_format(DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_Default,
                  DataType.F16_Default, DataType.F16_Default, DataType.F16_FracZ, DataType.F16_FracZ,
                  DataType.F16_FracZ, DataType.F16_FracZ) \
    .dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_Default,
                  DataType.F16_Default, DataType.F16_Default, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0,
                  DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0) \
    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
                  DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
                  DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_Default,
                  DataType.F32_Default, DataType.F32_Default, DataType.F32_5HD, DataType.F32_5HD,
                  DataType.F32_5HD, DataType.F32_5HD) \
    .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_Default,
                  DataType.F32_Default, DataType.F32_Default, DataType.F32_FracZ, DataType.F32_FracZ,
                  DataType.F32_FracZ, DataType.F32_FracZ) \
    .dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_Default,
                  DataType.F32_Default, DataType.F32_Default, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0,
                  DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
                  DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
                  DataType.F32_Default, DataType.F32_Default) \
    .get_op_info()


@op_info_register(apply_adadelta_d_op_info)
def _apply_adadelta_tbe():
    """ApplyAdadeltaD TBE register"""
    return
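A minimal NumPy sketch of the Adadelta step that apply_adadelta_d implements, assuming the formulas in the ApplyAdadelta docstring below (the helper is illustrative only):

import numpy as np

def adadelta_reference(var, accum, accum_update, lr, rho, epsilon, grad):
    """Reference Adadelta step: returns the updated (var, accum, accum_update)."""
    accum = rho * accum + (1 - rho) * grad * grad
    update = np.sqrt(accum_update + epsilon) / np.sqrt(accum + epsilon) * grad
    accum_update = rho * accum_update + (1 - rho) * update * update
    var = var - lr * update
    return var, accum, accum_update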
@@ -0,0 +1,55 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""ApplyAdagradD op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

apply_adagrad_d_op_info = TBERegOp("ApplyAdagrad") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("apply_adagrad_d.so") \
    .compute_cost(10) \
    .kernel_name("apply_adagrad_d") \
    .partial_flag(True) \
    .attr("update_slots", "optional", "bool", "true,false", "false") \
    .input(0, "var", False, "required", "all") \
    .input(1, "accum", False, "required", "all") \
    .input(2, "lr", False, "required", "all") \
    .input(3, "grad", False, "required", "all") \
    .output(0, "var", False, "required", "all") \
    .output(1, "accum", False, "required", "all") \
    .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_Default, DataType.F16_5HD,
                  DataType.F16_5HD, DataType.F16_5HD) \
    .dtype_format(DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_Default, DataType.F16_FracZ,
                  DataType.F16_FracZ, DataType.F16_FracZ) \
    .dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_Default, DataType.F16_C1HWNCoC0,
                  DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0) \
    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
                  DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_Default, DataType.F32_5HD,
                  DataType.F32_5HD, DataType.F32_5HD) \
    .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_Default, DataType.F32_FracZ,
                  DataType.F32_FracZ, DataType.F32_FracZ) \
    .dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_Default, DataType.F32_C1HWNCoC0,
                  DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
                  DataType.F32_Default, DataType.F32_Default) \
    .get_op_info()


@op_info_register(apply_adagrad_d_op_info)
def _apply_adagrad_tbe():
    """ApplyAdagradD TBE register"""
    return
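A minimal NumPy sketch of the Adagrad step behind apply_adagrad_d, assuming the formulas in the ApplyAdagrad docstring below (the helper is illustrative only):

import numpy as np

def adagrad_reference(var, accum, lr, grad, update_slots=True):
    """Reference Adagrad step: returns the updated (var, accum)."""
    if update_slots:
        accum = accum + grad * grad               # accumulate squared gradients
    var = var - lr * grad / np.sqrt(accum)
    return var, accum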
@@ -0,0 +1,56 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""ApplyAdagradV2D op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

apply_adagrad_v2_d_op_info = TBERegOp("ApplyAdagradV2") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("apply_adagradv2_d.so") \
    .compute_cost(10) \
    .kernel_name("apply_adagradv2_d") \
    .partial_flag(True) \
    .attr("epsilon", "required", "float", "all") \
    .attr("update_slots", "optional", "bool", "true,false", "false") \
    .input(0, "var", False, "required", "all") \
    .input(1, "accum", False, "required", "all") \
    .input(2, "lr", False, "required", "all") \
    .input(3, "grad", False, "required", "all") \
    .output(0, "var", False, "required", "all") \
    .output(1, "accum", False, "required", "all") \
    .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_Default, DataType.F16_5HD,
                  DataType.F16_5HD, DataType.F16_5HD) \
    .dtype_format(DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_Default, DataType.F16_FracZ,
                  DataType.F16_FracZ, DataType.F16_FracZ) \
    .dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_Default, DataType.F16_C1HWNCoC0,
                  DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0) \
    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
                  DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_Default, DataType.F32_5HD,
                  DataType.F32_5HD, DataType.F32_5HD) \
    .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_Default, DataType.F32_FracZ,
                  DataType.F32_FracZ, DataType.F32_FracZ) \
    .dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_Default, DataType.F32_C1HWNCoC0,
                  DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
                  DataType.F32_Default, DataType.F32_Default) \
    .get_op_info()


@op_info_register(apply_adagrad_v2_d_op_info)
def _apply_adagrad_v2_tbe():
    """ApplyAdagradV2D TBE register"""
    return
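ApplyAdagradV2 differs from ApplyAdagrad only in the epsilon term added to the denominator, which is supplied as a required attribute rather than an input. A minimal NumPy sketch under the same assumptions as the previous helpers (illustrative only):

import numpy as np

def adagrad_v2_reference(var, accum, lr, grad, epsilon, update_slots=True):
    """Reference AdagradV2 step: returns the updated (var, accum)."""
    if update_slots:
        accum = accum + grad * grad
    var = var - lr * grad / (np.sqrt(accum) + epsilon)
    return var, accum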
@@ -72,6 +72,7 @@ from .nn_ops import (LSTM, SGD, Adam, SparseApplyAdam, SparseApplyLazyAdam, Appl
                     SparseSoftmaxCrossEntropyWithLogits, Tanh,
                     TopK, BinaryCrossEntropy, SparseApplyAdagrad, LARSUpdate, ApplyFtrl, SparseApplyFtrl,
                     ApplyProximalAdagrad, SparseApplyProximalAdagrad,
                     ApplyAdaMax, ApplyAdadelta, ApplyAdagrad, ApplyAdagradV2,
                     ApplyRMSProp, ApplyCenteredRMSProp, BasicLSTMCell)
from .other_ops import (Assign, IOU, BoundingBoxDecode, BoundingBoxEncode,
                        CheckValid, MakeRefKey, Partial, Depend, CheckBprop, ConfusionMatrix)
@@ -282,6 +283,10 @@ __all__ = [
    "SparseApplyFtrl",
    "ApplyProximalAdagrad",
    "SparseApplyProximalAdagrad",
    "ApplyAdaMax",
    "ApplyAdadelta",
    "ApplyAdagrad",
    "ApplyAdagradV2",
    "BatchToSpace",
    "Atan2",
    "ApplyRMSProp",
@@ -3075,6 +3075,283 @@ class BinaryCrossEntropy(PrimitiveWithInfer):
        return x_type


class ApplyAdaMax(PrimitiveWithInfer):
    r"""
    Update relevant entries according to the adamax scheme.

    The updating formulas are as follows,

    .. math::
        \begin{array}{ll} \\
            m_{t} = \beta_1 * m_{t-1} + (1 - \beta_1) * g \\
            v_{t} = \max(\beta_2 * v_{t-1}, \left| g \right|) \\
            var = var - \frac{l}{1 - \beta_1^t} * \frac{m_{t}}{v_{t} + \epsilon}
        \end{array}

    :math:`t` represents the current updating step, :math:`m` represents the 1st moment vector, :math:`m_{t-1}`
    is the value of :math:`m` at the previous step, :math:`v` represents the 2nd moment vector, :math:`v_{t-1}`
    is the value of :math:`v` at the previous step, :math:`l` represents the scaling factor `lr`,
    :math:`g` represents `grad`, :math:`\beta_1, \beta_2` represent `beta1` and `beta2`,
    :math:`\beta_1^t` represents `beta1_power`, :math:`var` represents the variable to be updated,
    :math:`\epsilon` represents `epsilon`.

    Inputs:
        - **var** (Parameter) - Variable to be updated.
        - **m** (Parameter) - The 1st moment vector in the updating formula. Has the same shape and type as `var`.
        - **v** (Parameter) - The 2nd moment vector in the updating formula. Mean square gradients,
          has the same shape and type as `var`.
        - **beta1_power** (float) - :math:`\beta_1^t` in the updating formula.
        - **lr** (float) - Learning rate, :math:`l` in the updating formula. Has the same type as `var`.
        - **beta1** (float) - The exponential decay rate for the 1st moment estimates.
        - **beta2** (float) - The exponential decay rate for the 2nd moment estimates.
        - **epsilon** (float) - A small value added for numerical stability.
        - **grad** (Tensor) - A tensor for gradient. Has the same shape and type as `var`.

    Outputs:
        Tuple of 3 Tensor, the updated parameters.

        - **var** (Tensor) - The same shape and data type as `var`.
        - **m** (Tensor) - The same shape and data type as `m`.
        - **v** (Tensor) - The same shape and data type as `v`.

    Examples:
        >>> var = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="var")
        >>> m = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="m")
        >>> v = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="v")
        >>> grad = Tensor(np.random.rand(3, 3).astype(np.float32))
        >>> beta1_power = 0.9
        >>> lr = 0.001
        >>> beta1 = 0.9
        >>> beta2 = 0.99
        >>> epsilon = 1e-10
        >>> apply_ada_max = P.ApplyAdaMax()
        >>> output = apply_ada_max(var, m, v, beta1_power, lr, beta1, beta2, epsilon, grad)
    """

    __mindspore_signature__ = (
        ('var', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
        ('m', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
        ('v', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
        ('beta1_power', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE,
         sig_dtype.T),
        ('lr', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
        ('beta1', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
        ('beta2', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
        ('epsilon', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
        ('grad', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T)
    )

    @prim_attr_register
    def __init__(self):
        """init ApplyAdaMax"""

    def infer_shape(self, var_shape, m_shape, v_shape, beta1_power_shape, lr_shape,
                    beta1_shape, beta2_shape, epsilon_shape, grad_shape):
        validator.check("var_shape", var_shape, "m_shape", m_shape, Rel.EQ, self.name)
        validator.check("var_shape", var_shape, "v_shape", v_shape, Rel.EQ, self.name)
        validator.check("var_shape", var_shape, "grad_shape", grad_shape, Rel.EQ, self.name)
        return var_shape, m_shape, v_shape

    def infer_dtype(self, var_dtype, m_dtype, v_dtype, beta1_power_dtype, lr_dtype,
                    beta1_dtype, beta2_dtype, epsilon_dtype, grad_dtype):
        args = {"var": var_dtype, "m": m_dtype, "v": v_dtype, "grad": grad_dtype}
        validator.check_tensor_type_same(args, mstype.number_type, self.name)

        scalar_args = {"beta1_power": beta1_power_dtype, 'lr': lr_dtype, "beta1": beta1_dtype,
                       "beta2": beta2_dtype, "epsilon": epsilon_dtype}
        validator.check_scalar_or_tensor_type_same(scalar_args, [mstype.float16, mstype.float32], self.name, True)
        return var_dtype, m_dtype, v_dtype


class ApplyAdadelta(PrimitiveWithInfer):
    r"""
    Update relevant entries according to the adadelta scheme.

    .. math::
        accum = \rho * accum + (1 - \rho) * grad^2
    .. math::
        update = \frac{\sqrt{accum_update + \epsilon}}{\sqrt{accum + \epsilon}} * grad
    .. math::
        accum_update = \rho * accum_update + (1 - \rho) * update^2
    .. math::
        var -= lr * update

    Inputs:
        - **var** (Parameter) - Weights to be updated.
        - **accum** (Parameter) - Accum to be updated, has the same shape and type as `var`.
        - **accum_update** (Parameter) - Accum_update to be updated, has the same shape and type as `var`.
        - **lr** (float) - Learning rate, has the same type as `var`.
        - **rho** (float) - Decay rate.
        - **epsilon** (float) - A small value added for numerical stability.
        - **grad** (Tensor) - Gradients, has the same shape and type as `var`.

    Outputs:
        Tuple of 3 Tensor, the updated parameters.

        - **var** (Tensor) - The same shape and data type as `var`.
        - **accum** (Tensor) - The same shape and data type as `accum`.
        - **accum_update** (Tensor) - The same shape and data type as `accum_update`.

    Examples:
        >>> var = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="var")
        >>> accum = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="accum")
        >>> accum_update = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="accum_update")
        >>> grad = Tensor(np.random.rand(3, 3).astype(np.float32))
        >>> lr = 0.001
        >>> rho = 0.0
        >>> epsilon = 1e-6
        >>> apply_adadelta = P.ApplyAdadelta()
        >>> output = apply_adadelta(var, accum, accum_update, lr, rho, epsilon, grad)
    """

    __mindspore_signature__ = (
        ('var', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
        ('accum', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
        ('accum_update', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE,
         sig_dtype.T),
        ('lr', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
        ('rho', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
        ('epsilon', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
        ('grad', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T)
    )

    @prim_attr_register
    def __init__(self):
        """init ApplyAdadelta"""

    def infer_shape(self, var_shape, accum_shape, accum_update_shape, lr_shape, rho_shape,
                    epsilon_shape, grad_shape):
        validator.check("var_shape", var_shape, "accum_shape", accum_shape, Rel.EQ, self.name)
        validator.check("var_shape", var_shape, "accum_update_shape", accum_update_shape, Rel.EQ, self.name)
        validator.check("var_shape", var_shape, "grad_shape", grad_shape, Rel.EQ, self.name)
        return var_shape, accum_shape, accum_update_shape

    def infer_dtype(self, var_dtype, accum_dtype, accum_update_dtype, lr_dtype, rho_shape,
                    epsilon_dtype, grad_dtype):
        args = {"var": var_dtype, "accum": accum_dtype, "accum_update": accum_update_dtype, "grad": grad_dtype}
        validator.check_tensor_type_same(args, mstype.number_type, self.name)

        scalar_args = {"lr": lr_dtype, "rho": rho_shape, "epsilon": epsilon_dtype}
        validator.check_scalar_or_tensor_type_same(scalar_args, [mstype.float16, mstype.float32], self.name, True)
        return var_dtype, accum_dtype, accum_update_dtype


class ApplyAdagrad(PrimitiveWithInfer):
    r"""
    Update relevant entries according to the adagrad scheme.

    .. math::
        accum += grad * grad
    .. math::
        var -= lr * grad * \frac{1}{\sqrt{accum}}

    Args:
        update_slots (bool): If `True`, `accum` will be updated. Default: True.

    Inputs:
        - **var** (Parameter) - Variable to be updated.
        - **accum** (Parameter) - Accum to be updated. The shape and dtype should be the same as `var`.
        - **lr** (float) - The learning rate value, has the same type as `var`.
        - **grad** (Tensor) - A tensor for gradient. The shape and dtype should be the same as `var`.

    Outputs:
        Tuple of 2 Tensor, the updated parameters.

        - **var** (Tensor) - The same shape and data type as `var`.
        - **accum** (Tensor) - The same shape and data type as `accum`.

    Examples:
        >>> var = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="var")
        >>> accum = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="accum")
        >>> grad = Tensor(np.random.rand(3, 3).astype(np.float32))
        >>> lr = 0.01
        >>> apply_adagrad = P.ApplyAdagrad()
        >>> output = apply_adagrad(var, accum, lr, grad)
    """

    __mindspore_signature__ = (
        ('var', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
        ('accum', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
        ('lr', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
        ('grad', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T)
    )

    @prim_attr_register
    def __init__(self, update_slots=True):
        validator.check_value_type("update_slots", update_slots, [bool], self.name)

    def infer_shape(self, var_shape, accum_shape, lr_shape, grad_shape):
        validator.check('var shape', var_shape, 'accum shape', accum_shape, Rel.EQ, self.name)
        validator.check('var shape', var_shape, 'grad shape', grad_shape, Rel.EQ, self.name)
        return var_shape, accum_shape

    def infer_dtype(self, var_dtype, accum_dtype, lr_dtype, grad_dtype):
        args = {'var': var_dtype, 'accum': accum_dtype, 'grad': grad_dtype}
        validator.check_tensor_type_same(args, mstype.number_type, self.name)
        valid_types = [mstype.float16, mstype.float32]
        validator.check_scalar_or_tensor_type_same({'lr': lr_dtype}, valid_types, self.name)
        return var_dtype, accum_dtype


class ApplyAdagradV2(PrimitiveWithInfer):
    r"""
    Update relevant entries according to the adagradv2 scheme.

    .. math::
        accum += grad * grad
    .. math::
        var -= lr * grad * \frac{1}{\sqrt{accum} + \epsilon}

    Args:
        epsilon (float): A small value added for numerical stability.
        update_slots (bool): If `True`, `accum` will be updated. Default: True.

    Inputs:
        - **var** (Parameter) - Variable to be updated.
        - **accum** (Parameter) - Accum to be updated. The shape and dtype should be the same as `var`.
        - **lr** (float) - The learning rate value, has the same type as `var`.
        - **grad** (Tensor) - A tensor for gradient. The shape and dtype should be the same as `var`.

    Outputs:
        Tuple of 2 Tensor, the updated parameters.

        - **var** (Tensor) - The same shape and data type as `var`.
        - **accum** (Tensor) - The same shape and data type as `accum`.

    Examples:
        >>> var = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="var")
        >>> accum = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="accum")
        >>> grad = Tensor(np.random.rand(3, 3).astype(np.float32))
        >>> lr = 0.01
        >>> apply_adagrad_v2 = P.ApplyAdagradV2(epsilon=1e-6)
        >>> output = apply_adagrad_v2(var, accum, lr, grad)
    """

    __mindspore_signature__ = (
        ('var', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
        ('accum', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
        ('lr', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
        ('grad', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T)
    )

    @prim_attr_register
    def __init__(self, epsilon, update_slots=True):
        validator.check_value_type("epsilon", epsilon, [float], self.name)
        validator.check_value_type("update_slots", update_slots, [bool], self.name)

    def infer_shape(self, var_shape, accum_shape, lr_shape, grad_shape):
        validator.check('var shape', var_shape, 'accum shape', accum_shape, Rel.EQ, self.name)
        validator.check('var shape', var_shape, 'grad shape', grad_shape, Rel.EQ, self.name)
        return var_shape, accum_shape

    def infer_dtype(self, var_dtype, accum_dtype, lr_dtype, grad_dtype):
        args = {'var': var_dtype, 'accum': accum_dtype, 'grad': grad_dtype}
        validator.check_tensor_type_same(args, mstype.number_type, self.name)
        valid_types = [mstype.float16, mstype.float32]
        validator.check_scalar_or_tensor_type_same({'lr': lr_dtype}, valid_types, self.name)
        return var_dtype, accum_dtype


class SparseApplyAdagrad(PrimitiveWithInfer):
    r"""
    Update relevant entries according to the adagrad scheme.
@@ -270,6 +270,67 @@ class ApplyProximalAdagradNet(nn.Cell):
        return out


class ApplyAdaMaxNet(nn.Cell):
    def __init__(self):
        super(ApplyAdaMaxNet, self).__init__()
        self.apply_ada_max = P.ApplyAdaMax()
        self.beta1_power = 0.9
        self.lr = 0.001
        self.beta1 = 0.9
        self.beta2 = 0.99
        self.epsilon = 1e-10
        self.var = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="var")
        self.m = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="m")
        self.v = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="v")

    def construct(self, grad):
        out = self.apply_ada_max(self.var, self.m, self.v, self.beta1_power, self.lr,
                                 self.beta1, self.beta2, self.epsilon, grad)
        return out


class ApplyAdadeltaNet(nn.Cell):
    def __init__(self):
        super(ApplyAdadeltaNet, self).__init__()
        self.apply_adadelta = P.ApplyAdadelta()
        self.lr = 0.001
        self.rho = 0.0
        self.epsilon = 1e-6
        self.var = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="var")
        self.accum = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="accum")
        self.accum_update = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="accum_update")

    def construct(self, grad):
        out = self.apply_adadelta(self.var, self.accum, self.accum_update, self.lr, self.rho, self.epsilon, grad)
        return out


class ApplyAdagradNet(nn.Cell):
    def __init__(self):
        super(ApplyAdagradNet, self).__init__()
        self.apply_adagrad = P.ApplyAdagrad()
        self.lr = 0.001
        self.var = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="var")
        self.accum = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="accum")

    def construct(self, grad):
        out = self.apply_adagrad(self.var, self.accum, self.lr, grad)
        return out


class ApplyAdagradV2Net(nn.Cell):
    def __init__(self):
        super(ApplyAdagradV2Net, self).__init__()
        self.apply_adagrad_v2 = P.ApplyAdagradV2(epsilon=1e-6)
        self.lr = 0.001
        self.var = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="var")
        self.accum = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="accum")

    def construct(self, grad):
        out = self.apply_adagrad_v2(self.var, self.accum, self.lr, grad)
        return out


class ApplyRMSNet(nn.Cell):
    def __init__(self):
        super(ApplyRMSNet, self).__init__()
@@ -1082,6 +1143,22 @@ test_case_nn_ops = [
        'block': SparseApplyProximalAdagradNet(),
        'desc_inputs': [[3, 3], Tensor(np.ones((3,), np.int32))],
        'skip': ['backward']}),
    ('ApplyAdaMax', {
        'block': ApplyAdaMaxNet(),
        'desc_inputs': [[3, 3]],
        'skip': ['backward']}),
    ('ApplyAdadelta', {
        'block': ApplyAdadeltaNet(),
        'desc_inputs': [[3, 3]],
        'skip': ['backward']}),
    ('ApplyAdagrad', {
        'block': ApplyAdagradNet(),
        'desc_inputs': [[3, 3]],
        'skip': ['backward']}),
    ('ApplyAdagradV2', {
        'block': ApplyAdagradV2Net(),
        'desc_inputs': [[3, 3]],
        'skip': ['backward']}),
    ('Flatten_1', {
        'block': NetForFlatten(),
        'desc_inputs': [Tensor(np.ones([2, 3, 4]).astype(np.int32)), Tensor(np.ones([2, 12]).astype(np.int32))],