forked from mindspore-Ecosystem/mindspore
!10458 update the description and example of applyftrl operator.
From: @wangshuide2020 Reviewed-by: @liangchenghui,@wuxuejian Signed-off-by: @liangchenghui
This commit is contained in:
commit
1a536233ff
|
@ -5764,7 +5764,12 @@ class ApplyFtrl(PrimitiveWithInfer):
|
|||
Default: -0.5. It must be a float number or a scalar tensor with float16 or float32 data type.
|
||||
|
||||
Outputs:
|
||||
Tensor, represents the updated `var`.
|
||||
There are three outputs for Ascend environment.
|
||||
- **var** (Tensor) - represents the updated `var`.
|
||||
- **accum** (Tensor) - represents the updated `accum`.
|
||||
- **linear** (Tensor) - represents the updated `linear`.
|
||||
There is only one output for GPU environment.
|
||||
- **var** (Tensor) - This value is alwalys zero and the input parameters has been updated in-place.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU``
|
||||
|
@ -5773,8 +5778,8 @@ class ApplyFtrl(PrimitiveWithInfer):
|
|||
>>> import mindspore
|
||||
>>> import mindspore.nn as nn
|
||||
>>> import numpy as np
|
||||
>>> from mindspore import Parameter
|
||||
>>> from mindspore import Tensor
|
||||
>>> from mindspore import Parameter, Tensor
|
||||
>>> import mindspore.context as context
|
||||
>>> from mindspore.ops import operations as ops
|
||||
>>> class ApplyFtrlNet(nn.Cell):
|
||||
... def __init__(self):
|
||||
|
@ -5797,7 +5802,9 @@ class ApplyFtrl(PrimitiveWithInfer):
|
|||
>>> net = ApplyFtrlNet()
|
||||
>>> input_x = Tensor(np.random.randint(-4, 4, (2, 2)), mindspore.float32)
|
||||
>>> output = net(input_x)
|
||||
>>> print(output)
|
||||
>>> is_tbe = context.get_context("device_target") == "Ascend"
|
||||
>>> if is_tbe:
|
||||
... print(output)
|
||||
(Tensor(shape=[2, 2], dtype=Float32, value=
|
||||
[[ 4.61418092e-01, 5.30964255e-01],
|
||||
[ 2.68715084e-01, 3.82065028e-01]]), Tensor(shape=[2, 2], dtype=Float32, value=
|
||||
|
@ -5805,6 +5812,16 @@ class ApplyFtrl(PrimitiveWithInfer):
|
|||
[ 1.43758726e+00, 9.89177322e+00]]), Tensor(shape=[2, 2], dtype=Float32, value=
|
||||
[[-1.86994812e+03, -1.64906018e+03],
|
||||
[-3.22187836e+02, -1.20163989e+03]]))
|
||||
>>> else:
|
||||
... print(net.var.asnumpy())
|
||||
[[0.4614181 0.5309642 ]
|
||||
[0.2687151 0.38206503]]
|
||||
... print(net.accum.asnumpy())
|
||||
[[16.423655 9.645894 ]
|
||||
[ 1.4375873 9.891773 ]]
|
||||
... print(net.linear.asnumpy())
|
||||
[[-1869.9479 -1649.0599]
|
||||
[ -322.1879 -1201.6399]]
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
|
|
Loading…
Reference in New Issue