forked from mindspore-Ecosystem/mindspore
add LARSUpdate example
parent 7c06d292c8
commit bb527bc5cf
@@ -2488,6 +2488,27 @@ class LARSUpdate(PrimitiveWithInfer):
    Outputs:
        Tensor, representing the new gradient.

    Examples:
        >>> from mindspore import Tensor
        >>> from mindspore.ops import operations as P
        >>> from mindspore.ops import functional as F
        >>> import mindspore.nn as nn
        >>> import numpy as np
        >>> class Net(nn.Cell):
        >>>     def __init__(self):
        >>>         super(Net, self).__init__()
        >>>         self.lars = P.LARSUpdate()
        >>>         self.reduce = P.ReduceSum()
        >>>     def construct(self, weight, gradient):
        >>>         w_square_sum = self.reduce(F.square(weight))
        >>>         grad_square_sum = self.reduce(F.square(gradient))
        >>>         grad_t = self.lars(weight, gradient, w_square_sum, grad_square_sum, 0.0, 1.0)
        >>>         return grad_t
        >>> weight = np.random.random(size=(2, 3)).astype(np.float32)
        >>> gradient = np.random.random(size=(2, 3)).astype(np.float32)
        >>> net = Net()
        >>> ms_output = net(Tensor(weight), Tensor(gradient))
    """

    @prim_attr_register
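Note for readers new to the op: the doctest above only wires P.LARSUpdate into a small Cell. Below is a rough NumPy sketch of the general layer-wise adaptive rate scaling (LARS) rule the op is based on. The trust_coefficient and eps values here are illustrative assumptions, and the exact formula inside the P.LARSUpdate kernel may differ.

import numpy as np

def lars_scale(weight, gradient, weight_decay=0.0, learning_rate=1.0,
               trust_coefficient=0.001, eps=1e-5):
    # General LARS idea: rescale the gradient by a layer-wise trust ratio
    # computed from the norms of the weight and the gradient.
    # Illustrative sketch only, not the kernel's exact implementation.
    w_norm = np.sqrt(np.sum(np.square(weight)))
    g_norm = np.sqrt(np.sum(np.square(gradient)))
    trust_ratio = trust_coefficient * w_norm / (g_norm + weight_decay * w_norm + eps)
    return learning_rate * trust_ratio * (gradient + weight_decay * weight)

weight = np.random.random((2, 3)).astype(np.float32)
gradient = np.random.random((2, 3)).astype(np.float32)
print(lars_scale(weight, gradient))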