!39279 update docs: apply_proximal_gradient_descent for proper dtype

Merge pull request !39279 from Yanzhi_YI/apply_proximal_gradient_descent
This commit is contained in:
i-robot 2022-08-01 09:30:46 +00:00 committed by Gitee
commit fc4791e0b4
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 2 additions and 2 deletions

View File

@ -22,7 +22,7 @@ mindspore.ops.ApplyProximalGradientDescent
- **alpha** (Union[Number, Tensor]) - 比例系数必须为标量。数据类型为float16或float32。
- **l1** (Union[Number, Tensor]) - l1正则化强度必须为标量。数据类型为float16或float32。
- **l2** (Union[Number, Tensor]) - l2正则化强度必须为标量。数据类型为float16或float32。
- **delta** (Tensor) - 梯度Tensorshape和数据类型与 `var` 相同
- **delta** (Tensor) - 梯度Tensor。
输出:
Tensor更新后的 `var`

View File

@ -6656,7 +6656,7 @@ class ApplyProximalGradientDescent(Primitive):
With float32 or float16 data type.
- **l2** (Union[Number, Tensor]) - l2 regularization strength, must be a scalar.
With float32 or float16 data type.
- **delta** (Tensor) - A tensor for the change, has the same shape and data type as `var`.
- **delta** (Tensor) - A tensor for the change.
Outputs:
Tensor, represents the updated `var`.