!10 Add operator adapting in ME for SparseApplyFtrlD
Merge pull request !10 from zhangzheng/SparseApplyFtrlD
This commit is contained in:
commit
aba38a2401
|
@ -171,6 +171,7 @@ const char kNameAbsGrad[] = "AbsGrad";
|
|||
const char kNameBinaryCrossEntropy[] = "BinaryCrossEntropy";
|
||||
const char kNameBinaryCrossEntropyGrad[] = "BinaryCrossEntropyGrad";
|
||||
const char kNameSparseApplyAdagrad[] = "SparseApplyAdagrad";
|
||||
const char kNameSparseApplyFtrlD[] = "SparseApplyFtrlD";
|
||||
const char kNameSpaceToDepth[] = "SpaceToDepth";
|
||||
const char kNameDepthToSpace[] = "DepthToSpace";
|
||||
const char kNameSign[] = "Sign";
|
||||
|
@ -353,6 +354,7 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma
|
|||
{string(kNameBinaryCrossEntropy), ADPT_DESC(BinaryCrossEntropy)},
|
||||
{string(kNameBinaryCrossEntropyGrad), ADPT_DESC(BinaryCrossEntropyGrad)},
|
||||
{string(kNameSparseApplyAdagrad), ADPT_DESC(SparseApplyAdagradD)},
|
||||
{string(kNameSparseApplyFtrlD), ADPT_DESC(SparseApplyFtrlD)},
|
||||
{string(kNameSpaceToDepth), ADPT_DESC(SpaceToDepth)},
|
||||
{string(kNameDepthToSpace), ADPT_DESC(DepthToSpace)},
|
||||
{string(kNameSign), ADPT_DESC(Sign)},
|
||||
|
|
|
@ -1120,6 +1120,19 @@ ATTR_MAP(SparseApplyAdagradD) = {{"lr", ATTR_DESC(lr, AnyTraits<float>())},
|
|||
{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
|
||||
OUTPUT_MAP(SparseApplyAdagradD) = {{0, OUTPUT_DESC(var)}};
|
||||
|
||||
// SparseApplyFtrlD
|
||||
INPUT_MAP(SparseApplyFtrlD) = {{1, INPUT_DESC(var)},
|
||||
{2, INPUT_DESC(accum)},
|
||||
{3, INPUT_DESC(linear)},
|
||||
{4, INPUT_DESC(grad)},
|
||||
{5, INPUT_DESC(indices)}};
|
||||
ATTR_MAP(SparseApplyFtrlD) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())},
|
||||
{"lr", ATTR_DESC(lr, AnyTraits<float>())},
|
||||
{"l1", ATTR_DESC(l1, AnyTraits<float>())},
|
||||
{"l2", ATTR_DESC(l2, AnyTraits<float>())},
|
||||
{"lr_power", ATTR_DESC(lr_power, AnyTraits<float>())}};
|
||||
OUTPUT_MAP(SparseApplyFtrlD) = {{0, OUTPUT_DESC(var)}};
|
||||
|
||||
// SpaceToDepth
|
||||
INPUT_MAP(SpaceToDepth) = {{1, INPUT_DESC(x)}};
|
||||
ATTR_MAP(SpaceToDepth) = {{"block_size", ATTR_DESC(block_size, AnyTraits<int64_t>())}};
|
||||
|
|
|
@ -435,6 +435,8 @@ DECLARE_OP_ADAPTER(Round)
|
|||
DECLARE_OP_USE_OUTPUT(Round)
|
||||
DECLARE_OP_ADAPTER(ApplyFtrl)
|
||||
DECLARE_OP_USE_OUTPUT(ApplyFtrl)
|
||||
DECLARE_OP_ADAPTER(SparseApplyFtrlD)
|
||||
DECLARE_OP_USE_OUTPUT(SparseApplyFtrlD)
|
||||
#ifdef ENABLE_GE
|
||||
DECLARE_OP_ADAPTER(Print)
|
||||
DECLARE_OP_USE_DYN_INPUT(Print)
|
||||
|
|
|
@ -65,7 +65,7 @@ from .nn_ops import (LSTM, SGD, Adam, ApplyMomentum, BatchNorm,
|
|||
SmoothL1Loss, Softmax,
|
||||
SoftmaxCrossEntropyWithLogits, ROIAlign,
|
||||
SparseSoftmaxCrossEntropyWithLogits, Tanh,
|
||||
TopK, BinaryCrossEntropy, SparseApplyAdagrad, LARSUpdate, ApplyFtrl)
|
||||
TopK, BinaryCrossEntropy, SparseApplyAdagrad, LARSUpdate, ApplyFtrl, SparseApplyFtrlD)
|
||||
from .other_ops import Assign, IOU, BoundingBoxDecode, BoundingBoxEncode, CheckValid, MakeRefKey
|
||||
|
||||
|
||||
|
@ -217,6 +217,7 @@ __all__ = [
|
|||
"Abs",
|
||||
"BinaryCrossEntropy",
|
||||
"SparseApplyAdagrad",
|
||||
"SparseApplyFtrlD",
|
||||
"SpaceToDepth",
|
||||
"DepthToSpace",
|
||||
"Conv2DBackpropInput",
|
||||
|
|
|
@ -2141,6 +2141,80 @@ class SparseApplyAdagrad(PrimitiveWithInfer):
|
|||
return var_type
|
||||
|
||||
|
||||
class SparseApplyFtrlD(PrimitiveWithInfer):
|
||||
r"""
|
||||
Conduct experiment on updating on parameters related to FTRL optimization algorithm.
|
||||
|
||||
.. math ::
|
||||
\text{accum} = \text{grad} * \text{grad}
|
||||
|
||||
.. math ::
|
||||
\text{linear} += \text{grad} + (\text{accum} ^ {\text{-lr_power}} -
|
||||
\frac{\text{accum} ^ \text{-lr_power}}{\text{lr}} * \text{var})
|
||||
|
||||
.. math ::
|
||||
\text{quadratic} = {\text{1.0}/({\text{accum}^\text{lr_power} * \text{lr}}) + 2*\text{l2}
|
||||
|
||||
.. math ::
|
||||
\text{var} = {\text{sign}({linear}) * \text{l1} - \text{linear}})/{ quadratic }
|
||||
if \vert linear \vert > l1 \ else \ 0.0
|
||||
|
||||
Args:
|
||||
lr (float): Learning rate.
|
||||
l1 (float): temp value NO.1.
|
||||
l2 (float): temp value No.2.
|
||||
lr_power (float): temp value used as power number.
|
||||
use_locking (bool): If true, updating the var and accum tensors will be protected. Default: False.
|
||||
|
||||
Inputs:
|
||||
- **var** (Tensor) - Variable to be update. The type must be float32.
|
||||
- **accum** (Tensor) - Accum to be update. The shape must be the same as `var`'s shape,
|
||||
the type must be float32.
|
||||
- **linear** (Tensor) - Linear to be update. The shape must be the same as `var`'s shape,
|
||||
the type must be float32.
|
||||
- **grad** (Tensor) - Gradient. The shape must be the same as `var`'s shape,
|
||||
the type must be float32.
|
||||
- **indices** (Tensor) - A vector of indices into the first dimension of 'var' and 'accum',
|
||||
the shape of `indices` must be the same as `grad` in first dimension, the type must be int32.
|
||||
|
||||
Output:
|
||||
Tensors, has the same shape and type as `var`.
|
||||
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self, lr, l1, l2, lr_power, use_locking=False):
|
||||
"""init SparseApplyFtrlD"""
|
||||
self.lr = validator.check_type("lr", lr, [float])
|
||||
self.l1 = validator.check_type("l1", l1, [float])
|
||||
self.l2 = validator.check_type("l2", l2, [float])
|
||||
self.lr_power = validator.check_type("lr_power", lr_power, [float])
|
||||
self.use_locking = validator.check_type("use_locking", use_locaking, [bool])
|
||||
|
||||
def infer_shape(self, var_shape, accum_shape, linear_shape, grad_shape, indices_shape):
|
||||
validator.check_param_equal('var shape', var_shape, 'accum shape', accum_shape)
|
||||
validator.check_param_equal('len of var shape', len(var_shape), 'len of grad shape', len(grad_shape))
|
||||
validator.check_param_equal('len of var shape', len(var_shape), 'len of linear shape', len(linear_shape))
|
||||
if len(var_shape) > 1:
|
||||
validator.check_param_equal('var_shape', var_shape[1:], 'grad_shape', grad_shape[1:])
|
||||
validator.check_param_equal('var_shape', var_shape[1:], 'linear_shape', linear_shape[1:])
|
||||
validator.check_integer("len of indices shape", len(indices_shape), 1, Rel.EQ)
|
||||
validator.check('the first dimension of grad', grad_shape[0],
|
||||
'the shape of indices', indices_shape[0], Rel.EQ)
|
||||
|
||||
return var_shape
|
||||
|
||||
def infer_dtype(self, var_type, accum_type, linear_type, grad_type, indices_type):
|
||||
validator.check_subclass("var_type", var_type, mstype.tensor)
|
||||
validator.check_subclass("accum_type", accum_type, mstype.tensor)
|
||||
validator.check_subclass("linear_type", linear_type, mstype.tensor)
|
||||
validator.check_subclass("grad_type", grad_type, mstype.tensor)
|
||||
validator.check_subclass("indices_type", indices_type, mstype.tensor)
|
||||
validator.check_subclass('indices_type', indices_type, [mstype.int32])
|
||||
|
||||
return var_type
|
||||
|
||||
|
||||
class LARSUpdate(PrimitiveWithInfer):
|
||||
"""
|
||||
Conduct lars (layer-wise adaptive rate scaling) update on the square sum of gradient.
|
||||
|
@ -2244,4 +2318,4 @@ class ApplyFtrl(PrimitiveWithInfer):
|
|||
validator.check_typename("l1", l1_type,[mstype.float16, mstype.float32])
|
||||
validator.check_typename("l2", l2_type,[mstype.float16, mstype.float32])
|
||||
validator.check_typename("lr_power", lr_power_type,[mstype.float16, mstype.float32])
|
||||
return var_type
|
||||
return var_type
|
||||
|
|
|
@ -749,6 +749,11 @@ test_case_nn_ops = [
|
|||
'desc_inputs': [[3, 3], [3, 3], [3, 3], Tensor(np.ones((3,), np.int32))],
|
||||
'desc_bprop': [3, 3],
|
||||
'skip': ['backward']}),
|
||||
('SparseApplyFtrlD', {
|
||||
'block': P.SparseApplyFtrlD(0.1, 0.1, 0.1, -0.1),
|
||||
'desc_inputs': [[3, 3], [3, 3], [3, 3], [3, 3], Tensor(2*np.ones((3,), np.int32))],
|
||||
'desc_bprop': [3, 3],
|
||||
'skip': ['backward']}),
|
||||
('Flatten_1', {
|
||||
'block': NetForFlatten(),
|
||||
'desc_inputs': [Tensor(np.ones([2, 3, 4]).astype(np.int32)), Tensor(np.ones([2, 12]).astype(np.int32))],
|
||||
|
|
Loading…
Reference in New Issue