forked from mindspore-Ecosystem/mindspore

Adapt ApplyProximalAdagrad and SparseApplyProximalAdagrad

parent b3f91a4f22
commit c7c6f5736b
@@ -75,6 +75,8 @@ static std::map<string, string> tbe_func_adapter_map = {
{"apply_adagrad", "apply_adagrad_d"},
{"apply_adagrad_v2", "apply_adagradv2_d"},
{"sparse_apply_adagrad", "sparse_apply_adagrad_d"},
{"apply_proximal_adagrad", "apply_proximal_adagrad_d"},
{"sparse_apply_proximal_adagrad", "sparse_apply_proximal_adagrad_d"},
{"transpose", "transpose_d"},
{"fill", "fill_d"},
{"unsorted_segment_sum", "unsorted_segment_sum_d"},

@@ -391,7 +391,7 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma
{string(kNameBinaryCrossEntropyGrad), ADPT_DESC(BinaryCrossEntropyGrad)},
{string(kNameSparseApplyAdagrad), ADPT_DESC(SparseApplyAdagradD)},
{string(kNameSparseApplyFtrlD), ADPT_DESC(SparseApplyFtrlD)},
{string(kNameApplyProximalAdagrad), ADPT_DESC(ApplyProximalAdagrad)},
{string(kNameApplyProximalAdagrad), ADPT_DESC(ApplyProximalAdagradD)},
{string(kNameAcosh), ADPT_DESC(Acosh)},
{string(kNameAcoshGrad), ADPT_DESC(AcoshGrad)},
{string(kNameFloorMod), ADPT_DESC(FloorMod)},

@@ -1170,11 +1170,11 @@ ATTR_MAP(SparseApplyAdagradD) = {{"lr", ATTR_DESC(lr, AnyTraits<float>())},
{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
OUTPUT_MAP(SparseApplyAdagradD) = {{0, OUTPUT_DESC(var)}};

// ApplyProximalAdagrad
INPUT_MAP(ApplyProximalAdagrad) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(accum)}, {3, INPUT_DESC(lr)},
{4, INPUT_DESC(l1)}, {5, INPUT_DESC(l2)}, {6, INPUT_DESC(grad)}};
ATTR_MAP(ApplyProximalAdagrad) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
OUTPUT_MAP(ApplyProximalAdagrad) = {{0, OUTPUT_DESC(var)}};
// ApplyProximalAdagradD
INPUT_MAP(ApplyProximalAdagradD) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(accum)}, {3, INPUT_DESC(lr)},
{4, INPUT_DESC(l1)}, {5, INPUT_DESC(l2)}, {6, INPUT_DESC(grad)}};
ATTR_MAP(ApplyProximalAdagradD) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
OUTPUT_MAP(ApplyProximalAdagradD) = {{0, OUTPUT_DESC(var)}, {1, OUTPUT_DESC(accum)}};

// SparseApplyFtrlD
INPUT_MAP(SparseApplyFtrlD) = {{1, INPUT_DESC(var)},

@@ -446,8 +446,8 @@ DECLARE_OP_ADAPTER(BinaryCrossEntropyGrad)
DECLARE_OP_USE_OUTPUT(BinaryCrossEntropyGrad)
DECLARE_OP_ADAPTER(SparseApplyAdagradD)
DECLARE_OP_USE_OUTPUT(SparseApplyAdagradD)
DECLARE_OP_ADAPTER(ApplyProximalAdagrad)
DECLARE_OP_USE_OUTPUT(ApplyProximalAdagrad)
DECLARE_OP_ADAPTER(ApplyProximalAdagradD)
DECLARE_OP_USE_OUTPUT(ApplyProximalAdagradD)
DECLARE_OP_ADAPTER(SpaceToDepth)
DECLARE_OP_USE_OUTPUT(SpaceToDepth)
DECLARE_OP_ADAPTER(DepthToSpace)

@@ -13,15 +13,15 @@
# limitations under the License.
# ============================================================================

"""ApplyProximalAdagrad op"""
"""ApplyProximalAdagradD op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

apply_proximal_adagrad_op_info = TBERegOp("ApplyProximalAdagrad") \
apply_proximal_adagrad_d_op_info = TBERegOp("ApplyProximalAdagrad") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("apply_proximal_adagrad.so") \
.binfile_name("apply_proximal_adagrad_d.so") \
.compute_cost(10) \
.kernel_name("apply_proximal_adagrad") \
.kernel_name("apply_proximal_adagrad_d") \
.partial_flag(True) \
.attr("use_locking", "optional", "bool", "true,false", "false") \
.input(0, "var", False, "required", "all") \

@@ -31,26 +31,27 @@ apply_proximal_adagrad_op_info = TBERegOp("ApplyProximalAdagrad") \
.input(4, "l2", False, "required", "all") \
.input(5, "grad", False, "required", "all") \
.output(0, "var", False, "required", "all") \
.output(1, "accum", False, "required", "all") \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_Default, DataType.F16_Default,
DataType.F16_Default, DataType.F16_5HD, DataType.F16_5HD) \
DataType.F16_Default, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_Default, DataType.F16_Default,
DataType.F16_Default, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0) \
DataType.F16_Default, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0) \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_Default, DataType.F16_Default,
DataType.F16_Default, DataType.F16_FracZ, DataType.F16_FracZ) \
DataType.F16_Default, DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_5HD, DataType.F32_5HD) \
DataType.F32_Default, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
.dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \
DataType.F32_Default, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_FracZ, DataType.F32_FracZ) \
DataType.F32_Default, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ) \
.get_op_info()


@op_info_register(apply_proximal_adagrad_op_info)
@op_info_register(apply_proximal_adagrad_d_op_info)
def _apply_proximal_adagrad():
"""ApplyProximalAdagrad TBE register"""
"""ApplyProximalAdagradD TBE register"""
return

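In the updated registration above, each dtype_format row carries eight DataType entries: one per declared input (var, accum, lr, l1, l2, grad) followed by one per output (var and the newly registered accum). The sparse registration below follows the same pattern, with indices as a seventh input. A purely illustrative Python sketch of that positional correspondence follows; the helper names are made up for this note and are not part of MindSpore:

# Hypothetical helper, for illustration only: pairs each DataType entry in a
# dtype_format row of the apply_proximal_adagrad_d registration with the
# tensor it describes.
APPLY_PROXIMAL_ADAGRAD_D_IO = [
    "var", "accum", "lr", "l1", "l2", "grad",   # inputs 0-5
    "var", "accum",                             # outputs 0-1; accum is the new eighth entry
]

def describe_row(row):
    """Return (tensor name, dtype/format) pairs for one dtype_format row."""
    return list(zip(APPLY_PROXIMAL_ADAGRAD_D_IO, row))

print(describe_row(["F16_5HD", "F16_5HD", "F16_Default", "F16_Default",
                    "F16_Default", "F16_5HD", "F16_5HD", "F16_5HD"]))
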
@@ -13,10 +13,10 @@
# limitations under the License.
# ============================================================================

"""SparseApplyProximalAdagrad op"""
"""SparseApplyProximalAdagradD op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

sparse_apply_proximal_adagrad_op_info = TBERegOp("SparseApplyProximalAdagrad") \
sparse_apply_proximal_adagrad_d_op_info = TBERegOp("SparseApplyProximalAdagrad") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("sparse_apply_proximal_adagrad.so") \

@@ -32,70 +32,101 @@ sparse_apply_proximal_adagrad_op_info = TBERegOp("SparseApplyProximalAdagrad") \
.input(5, "grad", False, "required", "all") \
.input(6, "indices", False, "required", "all") \
.output(0, "var", False, "required", "all") \
.output(1, "accum", False, "required", "all") \
.dtype_format(DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW,
DataType.F32_NCHW, DataType.F32_NCHW, DataType.I16_NCHW, DataType.F32_NCHW) \
DataType.F32_NCHW, DataType.F32_NCHW, DataType.I16_NCHW, DataType.F32_NCHW,
DataType.F32_NCHW) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
DataType.F32_5HD, DataType.F32_5HD, DataType.I16_5HD, DataType.F32_5HD) \
DataType.F32_5HD, DataType.F32_5HD, DataType.I16_5HD, DataType.F32_5HD,
DataType.F32_5HD) \
.dtype_format(DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC,
DataType.F32_NHWC, DataType.F32_NHWC, DataType.I16_NHWC, DataType.F32_NHWC) \
DataType.F32_NHWC, DataType.F32_NHWC, DataType.I16_NHWC, DataType.F32_NHWC,
DataType.F32_NHWC) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_Default, DataType.I16_Default, DataType.F32_Default) \
DataType.F32_Default, DataType.F32_Default, DataType.I16_Default, DataType.F32_Default,
DataType.F32_Default) \
.dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ,
DataType.F32_FracZ, DataType.F32_FracZ, DataType.I16_FracZ, DataType.F32_FracZ) \
DataType.F32_FracZ, DataType.F32_FracZ, DataType.I16_FracZ, DataType.F32_FracZ,
DataType.F32_FracZ) \
.dtype_format(DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW,
DataType.F32_NCHW, DataType.F32_NCHW, DataType.I32_NCHW, DataType.F32_NCHW) \
DataType.F32_NCHW, DataType.F32_NCHW, DataType.I32_NCHW, DataType.F32_NCHW,
DataType.F32_NCHW) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
DataType.F32_5HD, DataType.F32_5HD, DataType.I32_5HD, DataType.F32_5HD) \
DataType.F32_5HD, DataType.F32_5HD, DataType.I32_5HD, DataType.F32_5HD,
DataType.F32_5HD) \
.dtype_format(DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC,
DataType.F32_NHWC, DataType.F32_NHWC, DataType.I32_NHWC, DataType.F32_NHWC) \
DataType.F32_NHWC, DataType.F32_NHWC, DataType.I32_NHWC, DataType.F32_NHWC,
DataType.F32_NHWC) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_Default, DataType.I32_Default, DataType.F32_Default) \
DataType.F32_Default, DataType.F32_Default, DataType.I32_Default, DataType.F32_Default,
DataType.F32_Default) \
.dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ,
DataType.F32_FracZ, DataType.F32_FracZ, DataType.I32_FracZ, DataType.F32_FracZ) \
DataType.F32_FracZ, DataType.F32_FracZ, DataType.I32_FracZ, DataType.F32_FracZ,
DataType.F32_FracZ) \
.dtype_format(DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW,
DataType.F32_NCHW, DataType.F32_NCHW, DataType.I64_NCHW, DataType.F32_NCHW) \
DataType.F32_NCHW, DataType.F32_NCHW, DataType.I64_NCHW, DataType.F32_NCHW,
DataType.F32_NCHW) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
DataType.F32_5HD, DataType.F32_5HD, DataType.I64_5HD, DataType.F32_5HD) \
DataType.F32_5HD, DataType.F32_5HD, DataType.I64_5HD, DataType.F32_5HD,
DataType.F32_5HD) \
.dtype_format(DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC,
DataType.F32_NHWC, DataType.F32_NHWC, DataType.I64_NHWC, DataType.F32_NHWC) \
DataType.F32_NHWC, DataType.F32_NHWC, DataType.I64_NHWC, DataType.F32_NHWC,
DataType.F32_NHWC) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_Default, DataType.I64_Default, DataType.F32_Default) \
DataType.F32_Default, DataType.F32_Default, DataType.I64_Default, DataType.F32_Default,
DataType.F32_Default) \
.dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ,
DataType.F32_FracZ, DataType.F32_FracZ, DataType.I64_FracZ, DataType.F32_FracZ) \
DataType.F32_FracZ, DataType.F32_FracZ, DataType.I64_FracZ, DataType.F32_FracZ,
DataType.F32_FracZ) \
.dtype_format(DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW,
DataType.F32_NCHW, DataType.F32_NCHW, DataType.U16_NCHW, DataType.F32_NCHW) \
DataType.F32_NCHW, DataType.F32_NCHW, DataType.U16_NCHW, DataType.F32_NCHW,
DataType.F32_NCHW) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
DataType.F32_5HD, DataType.F32_5HD, DataType.U16_5HD, DataType.F32_5HD) \
DataType.F32_5HD, DataType.F32_5HD, DataType.U16_5HD, DataType.F32_5HD,
DataType.F32_5HD) \
.dtype_format(DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC,
DataType.F32_NHWC, DataType.F32_NHWC, DataType.U16_NHWC, DataType.F32_NHWC) \
DataType.F32_NHWC, DataType.F32_NHWC, DataType.U16_NHWC, DataType.F32_NHWC,
DataType.F32_NHWC) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_Default, DataType.U16_Default, DataType.F32_Default) \
DataType.F32_Default, DataType.F32_Default, DataType.U16_Default, DataType.F32_Default,
DataType.F32_Default) \
.dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ,
DataType.F32_FracZ, DataType.F32_FracZ, DataType.U16_FracZ, DataType.F32_FracZ) \
DataType.F32_FracZ, DataType.F32_FracZ, DataType.U16_FracZ, DataType.F32_FracZ,
DataType.F32_FracZ) \
.dtype_format(DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW,
DataType.F32_NCHW, DataType.F32_NCHW, DataType.U32_NCHW, DataType.F32_NCHW) \
DataType.F32_NCHW, DataType.F32_NCHW, DataType.U32_NCHW, DataType.F32_NCHW,
DataType.F32_NCHW) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
DataType.F32_5HD, DataType.F32_5HD, DataType.U32_5HD, DataType.F32_5HD) \
DataType.F32_5HD, DataType.F32_5HD, DataType.U32_5HD, DataType.F32_5HD,
DataType.F32_5HD) \
.dtype_format(DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC,
DataType.F32_NHWC, DataType.F32_NHWC, DataType.U32_NHWC, DataType.F32_NHWC) \
DataType.F32_NHWC, DataType.F32_NHWC, DataType.U32_NHWC, DataType.F32_NHWC,
DataType.F32_NHWC) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_Default, DataType.U32_Default, DataType.F32_Default) \
DataType.F32_Default, DataType.F32_Default, DataType.U32_Default, DataType.F32_Default,
DataType.F32_Default) \
.dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ,
DataType.F32_FracZ, DataType.F32_FracZ, DataType.U32_FracZ, DataType.F32_FracZ) \
DataType.F32_FracZ, DataType.F32_FracZ, DataType.U32_FracZ, DataType.F32_FracZ,
DataType.F32_FracZ) \
.dtype_format(DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW,
DataType.F32_NCHW, DataType.F32_NCHW, DataType.U64_NCHW, DataType.F32_NCHW) \
DataType.F32_NCHW, DataType.F32_NCHW, DataType.U64_NCHW, DataType.F32_NCHW,
DataType.F32_NCHW) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
DataType.F32_5HD, DataType.F32_5HD, DataType.U64_5HD, DataType.F32_5HD) \
DataType.F32_5HD, DataType.F32_5HD, DataType.U64_5HD, DataType.F32_5HD,
DataType.F32_5HD) \
.dtype_format(DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC,
DataType.F32_NHWC, DataType.F32_NHWC, DataType.U64_NHWC, DataType.F32_NHWC) \
DataType.F32_NHWC, DataType.F32_NHWC, DataType.U64_NHWC, DataType.F32_NHWC,
DataType.F32_NHWC) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_Default, DataType.U64_Default, DataType.F32_Default) \
DataType.F32_Default, DataType.F32_Default, DataType.U64_Default, DataType.F32_Default,
DataType.F32_Default) \
.dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ,
DataType.F32_FracZ, DataType.F32_FracZ, DataType.U64_FracZ, DataType.F32_FracZ) \
DataType.F32_FracZ, DataType.F32_FracZ, DataType.U64_FracZ, DataType.F32_FracZ,
DataType.F32_FracZ) \
.get_op_info()


@op_info_register(sparse_apply_proximal_adagrad_op_info)
@op_info_register(sparse_apply_proximal_adagrad_d_op_info)
def _sparse_apply_proximal_adagrad():
"""SparseApplyProximalAdagrad TBE register"""
"""SparseApplyProximalAdagradD TBE register"""
return

@@ -3142,7 +3142,7 @@ class ApplyAdaMax(PrimitiveWithInfer):
.. math::
\begin{array}{ll} \\
m_{t} = \beta_1 * m_{t-1} + (1 - \beta_1) * g \\
v_{t} = \max(\beta_2 * v{t-1}, \left| g \right|) \\
v_{t} = \max(\beta_2 * v_{t-1}, \left| g \right|) \\
var = var - \frac{l}{1 - \beta_1^t} * \frac{m_{t}}{v_{t} + \epsilon}
\end{array}

@@ -3497,37 +3497,61 @@ class ApplyProximalAdagrad(PrimitiveWithInfer):
.. math::
accum += grad * grad
.. math::
prox_v = var - lr * grad * \frac{1}{\sqrt{accum}}
\text{prox_v} = var - lr * grad * \frac{1}{\sqrt{accum}}
.. math::
var = \frac{sign(prox_v)}{1 + lr * l2} * \max(\left| prox_v \right| - lr * l1, 0)
var = \frac{sign(\text{prox_v})}{1 + lr * l2} * \max(\left| \text{prox_v} \right| - lr * l1, 0)

Args:
use_locking (bool): If True, updating of the var and accum tensors will be protected. Default: False.

Inputs:
- **var** (Tensor) - Variable to be updated.
- **accum** (Tensor) - Accum to be updated. The shape must be the same as `var`'s shape.
- **var** (Parameter) - Variable to be updated. The data type should be float.
- **accum** (Parameter) - Accum to be updated. Must have the same shape and dtype as `var`.
- **lr** (Union[Number, Tensor]): The learning rate value. It should be a scalar tensor or number.
The data type should be float.
- **l1** (Union[Number, Tensor]): l1 regularization strength, must be greater than or equal to zero.
It should be a scalar tensor or number.
It should be a scalar tensor or number. The data type should be float.
- **l2** (Union[Number, Tensor]): l2 regularization strength, must be greater than or equal to zero.
It should be a scalar tensor or number.
- **grad** (Tensor) - Gradient. The shape must be the same as `var`'s shape.
It should be a scalar tensor or number. The data type should be float.
- **grad** (Tensor) - Gradient. Must have the same shape and dtype as `var`.

Outputs:
Tensor, has the same shape and type as `var`.
Tuple of 2 Tensors, the updated parameters.

- **var** (Tensor) - The same shape and data type as `var`.
- **accum** (Tensor) - The same shape and data type as `accum`.

Examples:
>>> var = Tensor(np.random.random((3, 3)), mindspore.float32)
>>> accum = Tensor(np.random.random((3, 3)), mindspore.float32)
>>> grad = Tensor(np.random.random((3, 3)), mindspore.float32)
>>> lr = 0.01
>>> l1 = 0.0
>>> l2 = 0.0
>>> apply_proximal_ada_grad = P.ApplyProximalAdagrad()
>>> output = apply_proximal_ada_grad(var, accum, lr, l1, l2, grad)
>>> import numpy as np
>>> import mindspore.nn as nn
>>> from mindspore import Tensor, Parameter
>>> from mindspore.ops import operations as P
>>> class Net(nn.Cell):
>>> def __init__(self):
>>> super(Net, self).__init__()
>>> self.apply_proximal_adagrad = P.ApplyProximalAdagrad()
>>> self.var = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="var")
>>> self.accum = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="accum")
>>> self.lr = 0.01
>>> self.l1 = 0.0
>>> self.l2 = 0.0
>>> def construct(self, grad):
>>> out = self.apply_proximal_adagrad(self.var, self.accum, self.lr, self.l1, self.l2, grad)
>>> return out
>>> net = Net()
>>> grad = Tensor(np.random.rand(3, 3).astype(np.float32))
>>> output = net(grad)
"""

__mindspore_signature__ = (
('var', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
('accum', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
('lr', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
('l1', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
('l2', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
('grad', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T)
)

@prim_attr_register
def __init__(self, use_locking=False):
self.init_prim_io_names(inputs=['var', 'accum', 'lr', 'l1', 'l2', 'grad'], outputs=['output'])

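For reference, the update described by the formulas in the docstring above can be sketched in NumPy. This is an illustration of the math only, not the TBE kernel, and the helper name is invented for this note; it also shows why the primitive now returns the (var, accum) pair.

# Illustrative NumPy sketch of the proximal Adagrad update (not the kernel).
import numpy as np

def proximal_adagrad_update(var, accum, lr, l1, l2, grad):
    accum = accum + grad * grad                              # accum += grad * grad
    prox_v = var - lr * grad / np.sqrt(accum)                # prox_v = var - lr * grad / sqrt(accum)
    var = np.sign(prox_v) / (1.0 + lr * l2) * np.maximum(np.abs(prox_v) - lr * l1, 0.0)
    return var, accum                                        # matches the new (var, accum) tuple output

var = np.random.rand(3, 3).astype(np.float32)
accum = np.random.rand(3, 3).astype(np.float32)
grad = np.random.rand(3, 3).astype(np.float32)
var, accum = proximal_adagrad_update(var, accum, lr=0.01, l1=0.0, l2=0.0, grad=grad)
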
@@ -3536,7 +3560,7 @@ class ApplyProximalAdagrad(PrimitiveWithInfer):
def infer_shape(self, var_shape, accum_shape, lr_shape, l1_shape, l2_shape, grad_shape):
validator.check('var shape', var_shape, 'accum shape', accum_shape, Rel.EQ, self.name)
validator.check('var shape', var_shape, 'grad shape', grad_shape, Rel.EQ, self.name)
return var_shape
return var_shape, accum_shape

def infer_dtype(self, var_dtype, accum_dtype, lr_dtype, l1_dtype, l2_dtype, grad_dtype):
valid_types = [mstype.float16, mstype.float32]

@@ -3544,7 +3568,7 @@ class ApplyProximalAdagrad(PrimitiveWithInfer):
validator.check_tensor_type_same(args, valid_types, self.name)
scalar_args = {"lr": lr_dtype, "l1": l1_dtype, "l2": l2_dtype}
validator.check_scalar_or_tensor_type_same(scalar_args, valid_types, self.name)
return var_dtype
return var_dtype, accum_dtype


class SparseApplyProximalAdagrad(PrimitiveWithInfer):

@@ -3555,39 +3579,65 @@ class SparseApplyProximalAdagrad(PrimitiveWithInfer):
.. math::
accum += grad * grad
.. math::
prox_v = var - lr * grad * \frac{1}{\sqrt{accum}}
\text{prox_v} = var - lr * grad * \frac{1}{\sqrt{accum}}
.. math::
var = \frac{sign(prox_v)}{1 + lr * l2} * \max(\left| prox_v \right| - lr * l1, 0)
var = \frac{sign(\text{prox_v})}{1 + lr * l2} * \max(\left| \text{prox_v} \right| - lr * l1, 0)

Args:
use_locking (bool): If True, updating of the var and accum tensors will be protected. Default: False.

Inputs:
- **var** (Tensor) - Variable tensor to be updated.
- **accum** (Tensor) - Variable tensor to be updated. The shape must be the same as `var`'s shape.
- **var** (Parameter) - Variable tensor to be updated. The data type must be float32.
- **accum** (Parameter) - Variable tensor to be updated. Has the same dtype as `var`.
- **lr** (Union[Number, Tensor]): The learning rate value. It should be a scalar tensor or number.
The data type must be float32.
- **l1** (Union[Number, Tensor]): l1 regularization strength, must be greater than or equal to zero.
It should be a scalar tensor or number.
It should be a scalar tensor or number. The data type must be float32.
- **l2** (Union[Number, Tensor]): l2 regularization strength, must be greater than or equal to zero.
It should be a scalar tensor or number.
- **grad** (Tensor) - A tensor of the same type as `var`, for the gradient.
It should be a scalar tensor or number. The data type must be float32.
- **grad** (Tensor) - A tensor of the same type as `var`, for the gradient. The data type must be float32.
- **indices** (Tensor) - A vector of indices into the first dimension of `var` and `accum`.

Outputs:
Tensor, has the same shape and type as `var`.
Tuple of 2 Tensors, the updated parameters.

- **var** (Tensor) - The same shape and data type as `var`.
- **accum** (Tensor) - The same shape and data type as `accum`.

Examples:
>>> var = Tensor(np.random.random((3, 3)), mindspore.float32)
>>> accum = Tensor(np.random.random((3, 3)), mindspore.float32)
>>> grad = Tensor(np.random.random((3, 3)), mindspore.float32)
>>> import numpy as np
>>> import mindspore.nn as nn
>>> from mindspore import Tensor, Parameter
>>> from mindspore.ops import operations as P
>>> class Net(nn.Cell):
>>> def __init__(self):
>>> super(Net, self).__init__()
>>> self.sparse_apply_proximal_adagrad = P.SparseApplyProximalAdagrad()
>>> self.var = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="var")
>>> self.accum = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="accum")
>>> self.lr = 0.01
>>> self.l1 = 0.0
>>> self.l2 = 0.0
>>> def construct(self, grad, indices):
>>> out = self.sparse_apply_proximal_adagrad(self.var, self.accum, self.lr, self.l1,
self.l2, grad, indices)
>>> return out
>>> net = Net()
>>> grad = Tensor(np.random.rand(3, 3).astype(np.float32))
>>> indices = Tensor(np.ones((3,), np.int32))
>>> lr = 0.01
>>> l1 = 0.0
>>> l2 = 0.0
>>> sparse_apply_proximal_ada_grad = P.SparseApplyProximalAdagrad()
>>> output = sparse_apply_proximal_ada_grad(var, accum, lr, l1, l2, grad, indices)
>>> output = net(grad, indices)
"""

__mindspore_signature__ = (
('var', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
('accum', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
('lr', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
('l1', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
('l2', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
('grad', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1)
)

@prim_attr_register
def __init__(self, use_locking=False):
self.init_prim_io_names(inputs=['var', 'accum', 'lr', 'l1', 'l2', 'grad', 'indices'],

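The sparse variant applies the same per-row update, but only to the rows of var and accum selected by indices, which is why infer_shape below now requires indices to be rank 1 and why both var and accum are returned. A rough NumPy sketch under that reading (illustrative only, not the kernel):

# Illustrative NumPy sketch of the sparse proximal Adagrad update.
import numpy as np

def sparse_proximal_adagrad_update(var, accum, lr, l1, l2, grad, indices):
    # grad has one row per entry of the rank-1 `indices` vector.
    for i, row in enumerate(indices):
        accum[row] += grad[i] * grad[i]
        prox_v = var[row] - lr * grad[i] / np.sqrt(accum[row])
        var[row] = np.sign(prox_v) / (1.0 + lr * l2) * np.maximum(np.abs(prox_v) - lr * l1, 0.0)
    return var, accum

var = np.random.rand(3, 3).astype(np.float32)
accum = np.random.rand(3, 3).astype(np.float32)
grad = np.random.rand(2, 3).astype(np.float32)
indices = np.array([0, 2], dtype=np.int32)
var, accum = sparse_proximal_adagrad_update(var, accum, 0.01, 0.0, 0.0, grad, indices)
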
@@ -3595,7 +3645,8 @@ class SparseApplyProximalAdagrad(PrimitiveWithInfer):
self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name)

def infer_shape(self, var_shape, accum_shape, lr_shape, l1_shape, l2_shape, grad_shape, indices_shape):
return var_shape
validator.check_integer("indices rank", len(indices_shape), 1, Rel.EQ, self.name)
return var_shape, accum_shape

def infer_dtype(self, var_dtype, accum_dtype, lr_dtype, l1_dtype, l2_dtype, grad_dtype, indices_dtype):
args = {'var': var_dtype, 'accum': accum_dtype, 'grad': grad_dtype}

@@ -3605,7 +3656,7 @@ class SparseApplyProximalAdagrad(PrimitiveWithInfer):
valid_types = [mstype.int16, mstype.int32, mstype.int64,
mstype.uint16, mstype.uint32, mstype.uint64]
validator.check_tensor_type_same({'indices': indices_dtype}, valid_types, self.name)
return var_dtype
return var_dtype, accum_dtype


class LARSUpdate(PrimitiveWithInfer):

@@ -3858,8 +3909,8 @@ class ConfusionMulGrad(PrimitiveWithInfer):
axis (Union[int, tuple[int], list[int]]): The dimensions to reduce.
Default:(), reduce all dimensions. Only constant value is allowed.
keep_dims (bool):
- If true, keep these reduced dimensions and the length is 1.
- If false, don't keep these dimensions. Default:False.
- If True, keep these reduced dimensions and the length is 1.
- If False, don't keep these dimensions. Default:False.

Inputs:
- **input_0** (Tensor) - The input Tensor.