Adapt ApplyProximalAdagrad and SparseApplyProximalAdagrad

liuxiao 2020-06-17 12:21:46 +08:00
parent b3f91a4f22
commit c7c6f5736b
7 changed files with 180 additions and 95 deletions


@@ -75,6 +75,8 @@ static std::map<string, string> tbe_func_adapter_map = {
{"apply_adagrad", "apply_adagrad_d"},
{"apply_adagrad_v2", "apply_adagradv2_d"},
{"sparse_apply_adagrad", "sparse_apply_adagrad_d"},
{"apply_proximal_adagrad", "apply_proximal_adagrad_d"},
{"sparse_apply_proximal_adagrad", "sparse_apply_proximal_adagrad_d"},
{"transpose", "transpose_d"},
{"fill", "fill_d"},
{"unsorted_segment_sum", "unsorted_segment_sum_d"},


@@ -391,7 +391,7 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma
{string(kNameBinaryCrossEntropyGrad), ADPT_DESC(BinaryCrossEntropyGrad)},
{string(kNameSparseApplyAdagrad), ADPT_DESC(SparseApplyAdagradD)},
{string(kNameSparseApplyFtrlD), ADPT_DESC(SparseApplyFtrlD)},
{string(kNameApplyProximalAdagrad), ADPT_DESC(ApplyProximalAdagrad)},
{string(kNameApplyProximalAdagrad), ADPT_DESC(ApplyProximalAdagradD)},
{string(kNameAcosh), ADPT_DESC(Acosh)},
{string(kNameAcoshGrad), ADPT_DESC(AcoshGrad)},
{string(kNameFloorMod), ADPT_DESC(FloorMod)},


@@ -1170,11 +1170,11 @@ ATTR_MAP(SparseApplyAdagradD) = {{"lr", ATTR_DESC(lr, AnyTraits<float>())},
{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
OUTPUT_MAP(SparseApplyAdagradD) = {{0, OUTPUT_DESC(var)}};
// ApplyProximalAdagrad
INPUT_MAP(ApplyProximalAdagrad) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(accum)}, {3, INPUT_DESC(lr)},
{4, INPUT_DESC(l1)}, {5, INPUT_DESC(l2)}, {6, INPUT_DESC(grad)}};
ATTR_MAP(ApplyProximalAdagrad) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
OUTPUT_MAP(ApplyProximalAdagrad) = {{0, OUTPUT_DESC(var)}};
// ApplyProximalAdagradD
INPUT_MAP(ApplyProximalAdagradD) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(accum)}, {3, INPUT_DESC(lr)},
{4, INPUT_DESC(l1)}, {5, INPUT_DESC(l2)}, {6, INPUT_DESC(grad)}};
ATTR_MAP(ApplyProximalAdagradD) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
OUTPUT_MAP(ApplyProximalAdagradD) = {{0, OUTPUT_DESC(var)}, {1, OUTPUT_DESC(accum)}};
// SparseApplyFtrlD
INPUT_MAP(SparseApplyFtrlD) = {{1, INPUT_DESC(var)},


@@ -446,8 +446,8 @@ DECLARE_OP_ADAPTER(BinaryCrossEntropyGrad)
DECLARE_OP_USE_OUTPUT(BinaryCrossEntropyGrad)
DECLARE_OP_ADAPTER(SparseApplyAdagradD)
DECLARE_OP_USE_OUTPUT(SparseApplyAdagradD)
DECLARE_OP_ADAPTER(ApplyProximalAdagrad)
DECLARE_OP_USE_OUTPUT(ApplyProximalAdagrad)
DECLARE_OP_ADAPTER(ApplyProximalAdagradD)
DECLARE_OP_USE_OUTPUT(ApplyProximalAdagradD)
DECLARE_OP_ADAPTER(SpaceToDepth)
DECLARE_OP_USE_OUTPUT(SpaceToDepth)
DECLARE_OP_ADAPTER(DepthToSpace)


@@ -13,15 +13,15 @@
# limitations under the License.
# ============================================================================
"""ApplyProximalAdagrad op"""
"""ApplyProximalAdagradD op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
apply_proximal_adagrad_op_info = TBERegOp("ApplyProximalAdagrad") \
apply_proximal_adagrad_d_op_info = TBERegOp("ApplyProximalAdagrad") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("apply_proximal_adagrad.so") \
.binfile_name("apply_proximal_adagrad_d.so") \
.compute_cost(10) \
.kernel_name("apply_proximal_adagrad") \
.kernel_name("apply_proximal_adagrad_d") \
.partial_flag(True) \
.attr("use_locking", "optional", "bool", "true,false", "false") \
.input(0, "var", False, "required", "all") \
@@ -31,26 +31,27 @@ apply_proximal_adagrad_op_info = TBERegOp("ApplyProximalAdagrad") \
.input(4, "l2", False, "required", "all") \
.input(5, "grad", False, "required", "all") \
.output(0, "var", False, "required", "all") \
.output(1, "accum", False, "required", "all") \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_Default, DataType.F16_Default,
DataType.F16_Default, DataType.F16_5HD, DataType.F16_5HD) \
DataType.F16_Default, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_Default, DataType.F16_Default,
DataType.F16_Default, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0) \
DataType.F16_Default, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0) \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_Default, DataType.F16_Default,
DataType.F16_Default, DataType.F16_FracZ, DataType.F16_FracZ) \
DataType.F16_Default, DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_5HD, DataType.F32_5HD) \
DataType.F32_Default, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
.dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \
DataType.F32_Default, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_FracZ, DataType.F32_FracZ) \
DataType.F32_Default, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ) \
.get_op_info()
@op_info_register(apply_proximal_adagrad_op_info)
@op_info_register(apply_proximal_adagrad_d_op_info)
def _apply_proximal_adagrad():
"""ApplyProximalAdagrad TBE register"""
"""ApplyProximalAdagradD TBE register"""
return
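For readers checking the widened rows above: each dtype_format entry lists one dtype/format per input in index order, followed by one per output, so registering accum as output 1 extends every row from seven to eight entries. A small illustrative restatement (not part of the commit) of the first float16 row:

# Illustrative only: a dtype_format row for ApplyProximalAdagradD carries one
# dtype/format per I/O in registration order -- six inputs, then two outputs.
io_slots = ["var", "accum", "lr", "l1", "l2", "grad",   # inputs 0..5
            "var_out", "accum_out"]                     # outputs 0..1
first_f16_row = ["F16_5HD", "F16_5HD", "F16_Default", "F16_Default",
                 "F16_Default", "F16_5HD", "F16_5HD", "F16_5HD"]
assert len(first_f16_row) == len(io_slots) == 8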


@@ -13,10 +13,10 @@
# limitations under the License.
# ============================================================================
"""SparseApplyProximalAdagrad op"""
"""SparseApplyProximalAdagradD op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
sparse_apply_proximal_adagrad_op_info = TBERegOp("SparseApplyProximalAdagrad") \
sparse_apply_proximal_adagrad_d_op_info = TBERegOp("SparseApplyProximalAdagrad") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("sparse_apply_proximal_adagrad.so") \
@@ -32,70 +32,101 @@ sparse_apply_proximal_adagrad_op_info = TBERegOp("SparseApplyProximalAdagrad") \
.input(5, "grad", False, "required", "all") \
.input(6, "indices", False, "required", "all") \
.output(0, "var", False, "required", "all") \
.output(1, "accum", False, "required", "all") \
.dtype_format(DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW,
DataType.F32_NCHW, DataType.F32_NCHW, DataType.I16_NCHW, DataType.F32_NCHW) \
DataType.F32_NCHW, DataType.F32_NCHW, DataType.I16_NCHW, DataType.F32_NCHW,
DataType.F32_NCHW) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
DataType.F32_5HD, DataType.F32_5HD, DataType.I16_5HD, DataType.F32_5HD) \
DataType.F32_5HD, DataType.F32_5HD, DataType.I16_5HD, DataType.F32_5HD,
DataType.F32_5HD) \
.dtype_format(DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC,
DataType.F32_NHWC, DataType.F32_NHWC, DataType.I16_NHWC, DataType.F32_NHWC) \
DataType.F32_NHWC, DataType.F32_NHWC, DataType.I16_NHWC, DataType.F32_NHWC,
DataType.F32_NHWC) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_Default, DataType.I16_Default, DataType.F32_Default) \
DataType.F32_Default, DataType.F32_Default, DataType.I16_Default, DataType.F32_Default,
DataType.F32_Default) \
.dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ,
DataType.F32_FracZ, DataType.F32_FracZ, DataType.I16_FracZ, DataType.F32_FracZ) \
DataType.F32_FracZ, DataType.F32_FracZ, DataType.I16_FracZ, DataType.F32_FracZ,
DataType.F32_FracZ) \
.dtype_format(DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW,
DataType.F32_NCHW, DataType.F32_NCHW, DataType.I32_NCHW, DataType.F32_NCHW) \
DataType.F32_NCHW, DataType.F32_NCHW, DataType.I32_NCHW, DataType.F32_NCHW,
DataType.F32_NCHW) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
DataType.F32_5HD, DataType.F32_5HD, DataType.I32_5HD, DataType.F32_5HD) \
DataType.F32_5HD, DataType.F32_5HD, DataType.I32_5HD, DataType.F32_5HD,
DataType.F32_5HD) \
.dtype_format(DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC,
DataType.F32_NHWC, DataType.F32_NHWC, DataType.I32_NHWC, DataType.F32_NHWC) \
DataType.F32_NHWC, DataType.F32_NHWC, DataType.I32_NHWC, DataType.F32_NHWC,
DataType.F32_NHWC) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_Default, DataType.I32_Default, DataType.F32_Default) \
DataType.F32_Default, DataType.F32_Default, DataType.I32_Default, DataType.F32_Default,
DataType.F32_Default) \
.dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ,
DataType.F32_FracZ, DataType.F32_FracZ, DataType.I32_FracZ, DataType.F32_FracZ) \
DataType.F32_FracZ, DataType.F32_FracZ, DataType.I32_FracZ, DataType.F32_FracZ,
DataType.F32_FracZ) \
.dtype_format(DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW,
DataType.F32_NCHW, DataType.F32_NCHW, DataType.I64_NCHW, DataType.F32_NCHW) \
DataType.F32_NCHW, DataType.F32_NCHW, DataType.I64_NCHW, DataType.F32_NCHW,
DataType.F32_NCHW) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
DataType.F32_5HD, DataType.F32_5HD, DataType.I64_5HD, DataType.F32_5HD) \
DataType.F32_5HD, DataType.F32_5HD, DataType.I64_5HD, DataType.F32_5HD,
DataType.F32_5HD) \
.dtype_format(DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC,
DataType.F32_NHWC, DataType.F32_NHWC, DataType.I64_NHWC, DataType.F32_NHWC) \
DataType.F32_NHWC, DataType.F32_NHWC, DataType.I64_NHWC, DataType.F32_NHWC,
DataType.F32_NHWC) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_Default, DataType.I64_Default, DataType.F32_Default) \
DataType.F32_Default, DataType.F32_Default, DataType.I64_Default, DataType.F32_Default,
DataType.F32_Default) \
.dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ,
DataType.F32_FracZ, DataType.F32_FracZ, DataType.I64_FracZ, DataType.F32_FracZ) \
DataType.F32_FracZ, DataType.F32_FracZ, DataType.I64_FracZ, DataType.F32_FracZ,
DataType.F32_FracZ) \
.dtype_format(DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW,
DataType.F32_NCHW, DataType.F32_NCHW, DataType.U16_NCHW, DataType.F32_NCHW) \
DataType.F32_NCHW, DataType.F32_NCHW, DataType.U16_NCHW, DataType.F32_NCHW,
DataType.F32_NCHW) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
DataType.F32_5HD, DataType.F32_5HD, DataType.U16_5HD, DataType.F32_5HD) \
DataType.F32_5HD, DataType.F32_5HD, DataType.U16_5HD, DataType.F32_5HD,
DataType.F32_5HD) \
.dtype_format(DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC,
DataType.F32_NHWC, DataType.F32_NHWC, DataType.U16_NHWC, DataType.F32_NHWC) \
DataType.F32_NHWC, DataType.F32_NHWC, DataType.U16_NHWC, DataType.F32_NHWC,
DataType.F32_NHWC) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_Default, DataType.U16_Default, DataType.F32_Default) \
DataType.F32_Default, DataType.F32_Default, DataType.U16_Default, DataType.F32_Default,
DataType.F32_Default) \
.dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ,
DataType.F32_FracZ, DataType.F32_FracZ, DataType.U16_FracZ, DataType.F32_FracZ) \
DataType.F32_FracZ, DataType.F32_FracZ, DataType.U16_FracZ, DataType.F32_FracZ,
DataType.F32_FracZ) \
.dtype_format(DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW,
DataType.F32_NCHW, DataType.F32_NCHW, DataType.U32_NCHW, DataType.F32_NCHW) \
DataType.F32_NCHW, DataType.F32_NCHW, DataType.U32_NCHW, DataType.F32_NCHW,
DataType.F32_NCHW) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
DataType.F32_5HD, DataType.F32_5HD, DataType.U32_5HD, DataType.F32_5HD) \
DataType.F32_5HD, DataType.F32_5HD, DataType.U32_5HD, DataType.F32_5HD,
DataType.F32_5HD) \
.dtype_format(DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC,
DataType.F32_NHWC, DataType.F32_NHWC, DataType.U32_NHWC, DataType.F32_NHWC) \
DataType.F32_NHWC, DataType.F32_NHWC, DataType.U32_NHWC, DataType.F32_NHWC,
DataType.F32_NHWC) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_Default, DataType.U32_Default, DataType.F32_Default) \
DataType.F32_Default, DataType.F32_Default, DataType.U32_Default, DataType.F32_Default,
DataType.F32_Default) \
.dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ,
DataType.F32_FracZ, DataType.F32_FracZ, DataType.U32_FracZ, DataType.F32_FracZ) \
DataType.F32_FracZ, DataType.F32_FracZ, DataType.U32_FracZ, DataType.F32_FracZ,
DataType.F32_FracZ) \
.dtype_format(DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW,
DataType.F32_NCHW, DataType.F32_NCHW, DataType.U64_NCHW, DataType.F32_NCHW) \
DataType.F32_NCHW, DataType.F32_NCHW, DataType.U64_NCHW, DataType.F32_NCHW,
DataType.F32_NCHW) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
DataType.F32_5HD, DataType.F32_5HD, DataType.U64_5HD, DataType.F32_5HD) \
DataType.F32_5HD, DataType.F32_5HD, DataType.U64_5HD, DataType.F32_5HD,
DataType.F32_5HD) \
.dtype_format(DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC,
DataType.F32_NHWC, DataType.F32_NHWC, DataType.U64_NHWC, DataType.F32_NHWC) \
DataType.F32_NHWC, DataType.F32_NHWC, DataType.U64_NHWC, DataType.F32_NHWC,
DataType.F32_NHWC) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_Default, DataType.U64_Default, DataType.F32_Default) \
DataType.F32_Default, DataType.F32_Default, DataType.U64_Default, DataType.F32_Default,
DataType.F32_Default) \
.dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ,
DataType.F32_FracZ, DataType.F32_FracZ, DataType.U64_FracZ, DataType.F32_FracZ) \
DataType.F32_FracZ, DataType.F32_FracZ, DataType.U64_FracZ, DataType.F32_FracZ,
DataType.F32_FracZ) \
.get_op_info()
@op_info_register(sparse_apply_proximal_adagrad_op_info)
@op_info_register(sparse_apply_proximal_adagrad_d_op_info)
def _sparse_apply_proximal_adagrad():
"""SparseApplyProximalAdagrad TBE register"""
"""SparseApplyProximalAdagradD TBE register"""
return
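The sparse registration follows the same pattern: with accum added as output 1, every dtype_format row now holds seven input entries (var, accum, lr, l1, l2, grad, indices) plus two output entries, and only the indices column changes its integer dtype across rows. An illustrative restatement (not part of the commit) of the first row:

# Illustrative only: seven inputs followed by the two outputs (var, accum);
# the indices slot is the only column whose dtype varies (I16/I32/I64/U16/U32/U64).
first_i16_row = ["F32_NCHW"] * 6 + ["I16_NCHW"] + ["F32_NCHW"] * 2
assert len(first_i16_row) == 7 + 2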


@@ -3142,7 +3142,7 @@ class ApplyAdaMax(PrimitiveWithInfer):
.. math::
\begin{array}{ll} \\
m_{t} = \beta_1 * m_{t-1} + (1 - \beta_1) * g \\
v_{t} = \max(\beta_2 * v{t-1}, \left| g \right|) \\
v_{t} = \max(\beta_2 * v_{t-1}, \left| g \right|) \\
var = var - \frac{l}{1 - \beta_1^t} * \frac{m_{t}}{v_{t} + \epsilon}
\end{array}
@@ -3497,37 +3497,61 @@ class ApplyProximalAdagrad(PrimitiveWithInfer):
.. math::
accum += grad * grad
.. math::
prox_v = var - lr * grad * \frac{1}{\sqrt{accum}}
\text{prox_v} = var - lr * grad * \frac{1}{\sqrt{accum}}
.. math::
var = \frac{sign(prox_v)}{1 + lr * l2} * \max(\left| prox_v \right| - lr * l1, 0)
var = \frac{sign(\text{prox_v})}{1 + lr * l2} * \max(\left| \text{prox_v} \right| - lr * l1, 0)
Args:
use_locking (bool): If True, updating of the var and accum tensors will be protected. Default: False.
Inputs:
- **var** (Tensor) - Variable to be updated.
- **accum** (Tensor) - Accum to be updated. The shape must be the same as `var`'s shape.
- **var** (Parameter) - Variable to be updated. The data type should be float.
- **accum** (Parameter) - Accum to be updated. Must have the same shape and dtype as `var`.
- **lr** (Union[Number, Tensor]): The learning rate value. It should be a scalar tensor or number.
The data type should be float.
- **l1** (Union[Number, Tensor]): l1 regularization strength, must be greater than or equal to zero.
It should be a scalar tensor or number.
It should be a scalar tensor or number. The data type should be float.
- **l2** (Union[Number, Tensor]): l2 regularization strength, must be greater than or equal to zero.
It should be a scalar tensor or number.
- **grad** (Tensor) - Gradient. The shape must be the same as `var`'s shape.
It should be a scalar tensor or number. The data type should be float.
- **grad** (Tensor) - Gradient. Must have the same shape and dtype as `var`.
Outputs:
Tensor, has the same shape and type as `var`.
Tuple of 2 Tensors, the updated parameters.
- **var** (Tensor) - The same shape and data type as `var`.
- **accum** (Tensor) - The same shape and data type as `accum`.
Examples:
>>> var = Tensor(np.random.random((3, 3)), mindspore.float32)
>>> accum = Tensor(np.random.random((3, 3)), mindspore.float32)
>>> grad = Tensor(np.random.random((3, 3)), mindspore.float32)
>>> lr = 0.01
>>> l1 = 0.0
>>> l2 = 0.0
>>> apply_proximal_ada_grad = P.ApplyProximalAdagrad()
>>> output = apply_proximal_ada_grad(var, accum, lr, l1, l2, grad)
>>> import numpy as np
>>> import mindspore.nn as nn
>>> from mindspore import Tensor, Parameter
>>> from mindspore.ops import operations as P
>>> class Net(nn.Cell):
>>> def __init__(self):
>>> super(Net, self).__init__()
>>> self.apply_proximal_adagrad = P.ApplyProximalAdagrad()
>>> self.var = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="var")
>>> self.accum = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="accum")
>>> self.lr = 0.01
>>> self.l1 = 0.0
>>> self.l2 = 0.0
>>> def construct(self, grad):
>>> out = self.apply_proximal_adagrad(self.var, self.accum, self.lr, self.l1, self.l2, grad)
>>> return out
>>> net = Net()
>>> grad = Tensor(np.random.rand(3, 3).astype(np.float32))
>>> output = net(grad)
"""
__mindspore_signature__ = (
('var', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
('accum', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
('lr', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
('l1', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
('l2', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
('grad', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T)
)
@prim_attr_register
def __init__(self, use_locking=False):
self.init_prim_io_names(inputs=['var', 'accum', 'lr', 'l1', 'l2', 'grad'], outputs=['output'])
@@ -3536,7 +3560,7 @@ class ApplyProximalAdagrad(PrimitiveWithInfer):
def infer_shape(self, var_shape, accum_shape, lr_shape, l1_shape, l2_shape, grad_shape):
validator.check('var shape', var_shape, 'accum shape', accum_shape, Rel.EQ, self.name)
validator.check('var shape', var_shape, 'grad shape', grad_shape, Rel.EQ, self.name)
return var_shape
return var_shape, accum_shape
def infer_dtype(self, var_dtype, accum_dtype, lr_dtype, l1_dtype, l2_dtype, grad_dtype):
valid_types = [mstype.float16, mstype.float32]
@@ -3544,7 +3568,7 @@ class ApplyProximalAdagrad(PrimitiveWithInfer):
validator.check_tensor_type_same(args, valid_types, self.name)
scalar_args = {"lr": lr_dtype, "l1": l1_dtype, "l2": l2_dtype}
validator.check_scalar_or_tensor_type_same(scalar_args, valid_types, self.name)
return var_dtype
return var_dtype, accum_dtype
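For readers following the formulas in the ApplyProximalAdagrad docstring above, here is a minimal NumPy sketch (illustrative only, not part of the commit) of the dense proximal-Adagrad step, returning both var and accum just as the primitive now does:

import numpy as np

def proximal_adagrad_update(var, accum, lr, l1, l2, grad):
    # accum += grad * grad
    accum = accum + grad * grad
    # prox_v = var - lr * grad / sqrt(accum)
    prox_v = var - lr * grad / np.sqrt(accum)
    # var = sign(prox_v) / (1 + lr * l2) * max(|prox_v| - lr * l1, 0)
    var = np.sign(prox_v) / (1.0 + lr * l2) * np.maximum(np.abs(prox_v) - lr * l1, 0.0)
    return var, accum

var = np.random.rand(3, 3).astype(np.float32)
accum = np.random.rand(3, 3).astype(np.float32)
grad = np.random.rand(3, 3).astype(np.float32)
new_var, new_accum = proximal_adagrad_update(var, accum, 0.01, 0.0, 0.0, grad)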
class SparseApplyProximalAdagrad(PrimitiveWithInfer):
@@ -3555,39 +3579,65 @@ class SparseApplyProximalAdagrad(PrimitiveWithInfer):
.. math::
accum += grad * grad
.. math::
prox_v = var - lr * grad * \frac{1}{\sqrt{accum}}
\text{prox_v} = var - lr * grad * \frac{1}{\sqrt{accum}}
.. math::
var = \frac{sign(prox_v)}{1 + lr * l2} * \max(\left| prox_v \right| - lr * l1, 0)
var = \frac{sign(\text{prox_v})}{1 + lr * l2} * \max(\left| \text{prox_v} \right| - lr * l1, 0)
Args:
use_locking (bool): If True, updating of the var and accum tensors will be protected. Default: False.
Inputs:
- **var** (Tensor) - Variable tensor to be updated.
- **accum** (Tensor) - Variable tensor to be updated. The shape must be the same as `var`'s shape.
- **var** (Parameter) - Variable tensor to be updated. The data type must be float32.
- **accum** (Parameter) - Variable tensor to be updated. Has the same dtype as `var`.
- **lr** (Union[Number, Tensor]): The learning rate value. It should be a scalar tensor or number.
The data type must be float32.
- **l1** (Union[Number, Tensor]): l1 regularization strength, must be greater than or equal to zero.
It should be a scalar tensor or number.
It should be a scalar tensor or number. The data type must be float32.
- **l2** (Union[Number, Tensor]): l2 regularization strength, must be greater than or equal to zero.
It should be a scalar tensor or number.
- **grad** (Tensor) - A tensor of the same type as `var`, for the gradient.
It should be a scalar tensor or number. The data type must be float32.
- **grad** (Tensor) - A tensor of the same type as `var`, for the gradient. The data type must be float32.
- **indices** (Tensor) - A vector of indices into the first dimension of `var` and `accum`.
Outputs:
Tensor, has the same shape and type as `var`.
Tuple of 2 Tensors, the updated parameters.
- **var** (Tensor) - The same shape and data type as `var`.
- **accum** (Tensor) - The same shape and data type as `accum`.
Examples:
>>> var = Tensor(np.random.random((3, 3)), mindspore.float32)
>>> accum = Tensor(np.random.random((3, 3)), mindspore.float32)
>>> grad = Tensor(np.random.random((3, 3)), mindspore.float32)
>>> import numpy as np
>>> import mindspore.nn as nn
>>> from mindspore import Tensor, Parameter
>>> from mindspore.ops import operations as P
>>> class Net(nn.Cell):
>>> def __init__(self):
>>> super(Net, self).__init__()
>>> self.sparse_apply_proximal_adagrad = P.SparseApplyProximalAdagrad()
>>> self.var = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="var")
>>> self.accum = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="accum")
>>> self.lr = 0.01
>>> self.l1 = 0.0
>>> self.l2 = 0.0
>>> def construct(self, grad, indices):
>>> out = self.sparse_apply_proximal_adagrad(self.var, self.accum, self.lr, self.l1,
>>> self.l2, grad, indices)
>>> return out
>>> net = Net()
>>> grad = Tensor(np.random.rand(3, 3).astype(np.float32))
>>> indices = Tensor(np.ones((3,), np.int32))
>>> lr = 0.01
>>> l1 = 0.0
>>> l2 = 0.0
>>> sparse_apply_proximal_ada_grad = P.SparseApplyProximalAdagrad()
>>> output = sparse_apply_proximal_ada_grad(var, accum, lr, l1, l2, grad, indices)
>>> output = net(grad, indices)
"""
__mindspore_signature__ = (
('var', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
('accum', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
('lr', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
('l1', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
('l2', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
('grad', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1)
)
@prim_attr_register
def __init__(self, use_locking=False):
self.init_prim_io_names(inputs=['var', 'accum', 'lr', 'l1', 'l2', 'grad', 'indices'],
@@ -3595,7 +3645,8 @@ class SparseApplyProximalAdagrad(PrimitiveWithInfer):
self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name)
def infer_shape(self, var_shape, accum_shape, lr_shape, l1_shape, l2_shape, grad_shape, indices_shape):
return var_shape
validator.check_integer("indices rank", len(indices_shape), 1, Rel.EQ, self.name)
return var_shape, accum_shape
def infer_dtype(self, var_dtype, accum_dtype, lr_dtype, l1_dtype, l2_dtype, grad_dtype, indices_dtype):
args = {'var': var_dtype, 'accum': accum_dtype, 'grad': grad_dtype}
@@ -3605,7 +3656,7 @@ class SparseApplyProximalAdagrad(PrimitiveWithInfer):
valid_types = [mstype.int16, mstype.int32, mstype.int64,
mstype.uint16, mstype.uint32, mstype.uint64]
validator.check_tensor_type_same({'indices': indices_dtype}, valid_types, self.name)
return var_dtype
return var_dtype, accum_dtype
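And a matching NumPy sketch of the sparse variant (illustrative only, not part of the commit, and assuming the usual row-wise semantics in which grad carries one slice per entry of the 1-D indices vector): the same proximal-Adagrad step, applied only to the rows of var and accum selected by indices, again returning both tensors.

import numpy as np

def sparse_proximal_adagrad_update(var, accum, lr, l1, l2, grad, indices):
    var, accum = var.copy(), accum.copy()
    for k, i in enumerate(indices):
        # update only row i of var/accum, using gradient slice k
        accum[i] = accum[i] + grad[k] * grad[k]
        prox_v = var[i] - lr * grad[k] / np.sqrt(accum[i])
        var[i] = np.sign(prox_v) / (1.0 + lr * l2) * np.maximum(np.abs(prox_v) - lr * l1, 0.0)
    return var, accum

var = np.random.rand(3, 3).astype(np.float32)
accum = np.random.rand(3, 3).astype(np.float32)
grad = np.random.rand(2, 3).astype(np.float32)   # one gradient slice per index
indices = np.array([0, 2], dtype=np.int32)       # 1-D, as checked by infer_shape above
new_var, new_accum = sparse_proximal_adagrad_update(var, accum, 0.01, 0.0, 0.0, grad, indices)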
class LARSUpdate(PrimitiveWithInfer):
@@ -3858,8 +3909,8 @@ class ConfusionMulGrad(PrimitiveWithInfer):
axis (Union[int, tuple[int], list[int]]): The dimensions to reduce.
Default:(), reduce all dimensions. Only constant value is allowed.
keep_dims (bool):
- If true, keep these reduced dimensions and the length is 1.
- If false, don't keep these dimensions. Default:False.
- If True, keep these reduced dimensions and the length is 1.
- If False, don't keep these dimensions. Default:False.
Inputs:
- **input_0** (Tensor) - The input Tensor.