fix some issues

This commit is contained in:
lilinjie 2022-08-06 17:14:14 +08:00
parent de495ba88b
commit 190e4ffc0e
7 changed files with 18 additions and 12 deletions

View File

@ -18,13 +18,16 @@
所有输入都遵循隐式类型转换规则,以使数据类型一致。如果 `lr``logbase``sign_decay``beta` 是数值型则会自动转换为Tensor数据类型与操作中涉及的Tensor的数据类型一致。如果输入是Tensor并且具有不同的数据类型则低精度数据类型将转换为最高精度的数据类型。
.. note::
目前Ascend平台上暂未开放对float64数据类型的支持。
输入:
- **var** (Parameter) - 要更新的变量。数据类型为float32或float16。如果 `var` 的数据类型为float16则所有输入的数据类型必须与 `var` 相同。shape :math:`(N, *)` ,其中 :math:`*` 表示任意数量的附加维度。
- **var** (Parameter) - 要更新的变量。数据类型为float64、float32或float16。如果 `var` 的数据类型为float16则所有输入的数据类型必须与 `var` 相同。shape :math:`(N, *)` ,其中 :math:`*` 表示任意数量的附加维度。
- **m** (Parameter) - 要更新的变量shape和数据类型与 `var` 相同。
- **lr** (Union[Number, Tensor]) - 学习率应该是Scalar或Tensor数据类型为float32或float16。
- **logbase** (Union[Number, Tensor]) - 应该是Scalar或Tensor数据类型为float32或float16。
- **sign_decay** (Union[Number, Tensor]) - 应该是Scalar或Tensor数据类型为float32或float16。
- **beta** (Union[Number, Tensor]) - 指数衰减率应该是Scalar或Tensor数据类型为float32或float16。
- **lr** (Union[Number, Tensor]) - 学习率应该是Scalar或Tensor数据类型为float64、float32或float16。
- **logbase** (Union[Number, Tensor]) - 应该是Scalar或Tensor数据类型为float64、float32或float16。
- **sign_decay** (Union[Number, Tensor]) - 应该是Scalar或Tensor数据类型为float64、float32或float16。
- **beta** (Union[Number, Tensor]) - 指数衰减率应该是Scalar或Tensor数据类型为float64、float32或float16。
- **grad** (Tensor) - 梯度shape和数据类型与 `var` 相同。
输出:
@ -34,7 +37,7 @@
- **m** (Tensor) - shape和数据类型与 `m` 相同。
异常:
- **TypeError** - 如果 `var``lr``logbase``sign_decay``beta``grad` 的数据类型既不是float16也不是float32
- **TypeError** - 如果 `var``lr``logbase``sign_decay``beta``grad` 的数据类型不是float16、float32或者float64
- **TypeError** - 如果 `lr``logbase``sign_decay``beta` 既不是数值型也不是Tensor。
- **TypeError** - 如果 `grad` 不是Tensor。
- **RuntimeError** - 如果 `lr``logbase``sign_decay``grad` 不支持数据类型转换。

View File

@ -4,5 +4,6 @@ mindspore.ops.Range
.. py:class:: mindspore.ops.Range(maxlen=1000000)
返回从 `start` 开始,步长为 `delta` ,且不超过 `limit` (不包括 `limit` )的序列。
序列的长度不能超过1000000。
更多参考详见 :func:`mindspore.ops.range`

View File

@ -22,6 +22,7 @@ mindspore.ops.squeeze
Tensorshape为 :math:`(x_1, x_2, ..., x_S)`
异常:
- **TypeError** - `input_x` 不是tensor。
- **TypeError** - `axis` 既不是int也不是tuple。
- **TypeError** - `axis` 是tuple其元素并非全部是int。
- **ValueError** - 指定 `axis` 的对应维度不等于1。

View File

@ -81,7 +81,7 @@ class Embedding(Cell):
Raises:
TypeError: If `vocab_size` or `embedding_size` is not an int.
TypeError: If `use_one_hot` is not a bool.
ValueError: If `padding_idx` is an int which not in range [0, `vocab_size`].
ValueError: If `padding_idx` is an int which not in range [0, `vocab_size`).
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
@ -109,7 +109,7 @@ class Embedding(Cell):
self.init_tensor = initializer(embedding_table, [vocab_size, embedding_size])
self.padding_idx = padding_idx
if padding_idx is not None:
self.padding_idx = validator.check_int_range(padding_idx, 0, vocab_size, Rel.INC_BOTH,
self.padding_idx = validator.check_int_range(padding_idx, 0, vocab_size, Rel.INC_LEFT,
"padding_idx", self.cls_name)
if isinstance(self.init_tensor, Tensor) and self.init_tensor.init is not None:
self.init_tensor = self.init_tensor.init_data()

View File

@ -1241,6 +1241,7 @@ def squeeze(input_x, axis=()):
Tensor, the shape of tensor is :math:`(x_1, x_2, ..., x_S)`.
Raises:
TypeError: If `input_x` is not a tensor.
TypeError: If `axis` is neither an int nor tuple.
TypeError: If `axis` is a tuple whose elements are not all int.
ValueError: If the corresponding dimension of the specified axis isn't equal to 1.

View File

@ -1815,13 +1815,13 @@ def grid_sample(input_x, grid, interpolation_mode='bilinear', padding_mode='zero
ValueError: If `padding_mode` is not "zeros", "border", "reflection" or a string value.
Supported Platforms:
``CPU`` ``GPU``
``GPU`` ``CPU``
Examples:
>>> input_x = Tensor(np.arange(16).reshape((2, 2, 2, 2)).astype(np.float32))
>>> grid = Tensor(np.arange(0.2, 1, 0.1).reshape((2, 2, 1, 2)).astype(np.float32))
>>> output = ops.grid_sample(input_x, grid, interpolation_mode='bilinear', padding_mode='zeros',
align_corners=True)
... align_corners=True)
>>> print(output)
[[[[ 1.9 ]
[ 2.1999998]]

View File

@ -6172,9 +6172,9 @@ class IdentityN(Primitive):
class Range(PrimitiveWithCheck):
r"""
Creates a sequence of numbers that begins at `start` and extends by increments of
`delta` up to but not including `limit`.
`delta` up to but not including `limit`. Length of the created sequence can not exceed 1000000.
Refer to :func:`mindspore.ops.range` for more detailed.
Refer to :func:`mindspore.ops.range` for more details.
Supported Platforms:
``GPU`` ``CPU``