add LARSUpdate example

This commit is contained in:
Ziyan 2020-04-24 15:47:10 +08:00
parent 7c06d292c8
commit bb527bc5cf
1 changed file with 21 additions and 0 deletions

View File

@ -2488,6 +2488,27 @@ class LARSUpdate(PrimitiveWithInfer):
Outputs:
Tensor, representing the new gradient.
Examples:
>>> from mindspore import Tensor
>>> from mindspore.ops import operations as P
>>> from mindspore.ops import functional as F
>>> import mindspore.nn as nn
>>> import numpy as np
>>> class Net(nn.Cell):
...     def __init__(self):
...         super(Net, self).__init__()
...         self.lars = P.LARSUpdate()
...         self.reduce = P.ReduceSum()
...     def construct(self, weight, gradient):
...         w_square_sum = self.reduce(F.square(weight))
...         grad_square_sum = self.reduce(F.square(gradient))
...         grad_t = self.lars(weight, gradient, w_square_sum, grad_square_sum, 0.0, 1.0)
...         return grad_t
>>> weight = np.random.random(size=(2, 3)).astype(np.float32)
>>> gradient = np.random.random(size=(2, 3)).astype(np.float32)
>>> net = Net()
>>> ms_output = net(Tensor(weight), Tensor(gradient))
"""
@prim_attr_register