!49102 fix d2l issues

Merge pull request !49102 from 吕昱峰(Nate.River)/d2l_bug_fix
This commit is contained in:
i-robot 2023-02-20 06:08:43 +00:00 committed by Gitee
commit 0aec3b71e6
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
6 changed files with 27 additions and 13 deletions

View File

@ -12,8 +12,8 @@ mindspore.ops.linspace
\end{aligned}
参数:
- **start** (Tensor) - 零维Tensor数据类型必须为float32。区间的起始值。
- **stop** (Tensor) - 零维Tensor数据类型必须为float32。区间的末尾值。
- **start** (Union[Tensor, int, float]) - 零维Tensor数据类型必须为float32。区间的起始值。
- **stop** (Union[Tensor, int, float]) - 零维Tensor数据类型必须为float32。区间的末尾值。
- **num** (Union[Tensor, int]) - 间隔中的包含的数值数量,包括区间端点。必须为正数。
返回:

View File

@ -7,8 +7,8 @@
参数:
- **shape** (tuple) - Tuple: :math:`(N,*)` ,其中 :math:`*` 表示任何数量的附加维度。
- **mean** (Tensor) - 均值μ,指定分布的峰值,数据类型支持[int8, int16, int32, int64, float16, float32]。
- **stddev** (Tensor) - 标准差σ。大于0。数据类型支持[int8, int16, int32, int64, float16, float32]。
- **mean** (Union[Tensor, int, float]) - 均值μ,指定分布的峰值,数据类型支持[int8, int16, int32, int64, float16, float32]。
- **stddev** (Union[Tensor, int, float]) - 标准差σ。大于0。数据类型支持[int8, int16, int32, int64, float16, float32]。
- **seed** (int) - 随机种子。取值须为非负数。默认值None等同于0。
返回:

View File

@ -13,8 +13,8 @@ mindspore.ops.one_hot
参数:
- **indices** (Tensor) - 输入索引shape为 :math:`(X_0, \ldots, X_n)` 的Tensor。数据类型必须为int32或int64。
- **depth** (int) - 输入的Scalar定义one-hot的深度。
- **on_value** (Tensor) - 当 `indices[j] = i`用来填充输出的值。数据类型为float16或float32。
- **off_value** (Tensor) - 当 `indices[j] != i` 时,用来填充输出的值。数据类型与 `on_value` 的相同。
- **on_value** (Union[Tensor, int, float]) - 当 `indices[j] = i`用来填充输出的值。数据类型为float16或float32。
- **off_value** (Union[Tensor, int, float]) - 当 `indices[j] != i` 时,用来填充输出的值。数据类型与 `on_value` 的相同。
- **axis** (int) - 指定one-hot的计算维度。例如如果 `indices` 的shape为 :math:`(N, C)` `axis` 为-1则输出shape为 :math:`(N, C, depth)` ,如果 `axis` 为0则输出shape为 :math:`(depth, N, C)` 。默认值:-1。
返回:

View File

@ -603,10 +603,10 @@ def one_hot(indices, depth, on_value, off_value, axis=-1):
indices(Tensor): A tensor of indices. Tensor of shape :math:`(X_0, \ldots, X_n)`.
Data type must be uint8, int32 or int64.
depth(int): A scalar defining the depth of the one-hot dimension.
on_value(Tensor): A value to fill in output when `indices[j] = i`.
on_value(Union[Tensor, int, float]): A value to fill in output when `indices[j] = i`.
Support uint8, uint16, uint32, uint64, int8, int16, int32, int64, float16, float32, float64,
bool, complex64, complex128.
off_value(Tensor): A value to fill in output when `indices[j] != i`.
off_value(Union[Tensor, int, float]): A value to fill in output when `indices[j] != i`.
Has the same data type as `on_value`.
axis(int): Position to insert the value. e.g. If shape of `self` is :math:`(N, C)`, and `axis` is -1,
the output shape will be :math:`(N, C, D)`, If `axis` is 0, the output shape will be :math:`(D, N, C)`.
@ -634,6 +634,10 @@ def one_hot(indices, depth, on_value, off_value, axis=-1):
[0. 1. 0.]
[0. 0. 1.]]
"""
if not isinstance(on_value, Tensor):
on_value = Tensor(on_value)
if not isinstance(off_value, Tensor):
off_value = Tensor(off_value)
onehot = _get_cache_prim(P.OneHot)(axis)
return onehot(indices, depth, on_value, off_value)

View File

@ -3180,8 +3180,10 @@ def linspace(start, stop, num):
\end{aligned}
Args:
start (Tensor): Start value of interval. The tensor data type must be float32 or float64 and with shape of 0-D.
stop (Tensor): Last value of interval. The tensor data type must be float32 or float64 and with shape of 0-D.
start (Union[Tensor, int, float]): Start value of interval. The tensor data type must be float32 or float64
and with shape of 0-D.
stop (Union[Tensor, int, float]): Last value of interval. The tensor data type must be float32 or float64
and with shape of 0-D.
num (Union[Tensor, int]): Number of ticks in the interval, inclusive of start and stop.
Must be positive int number or 0D int32/int64 Tensor.
@ -3206,6 +3208,10 @@ def linspace(start, stop, num):
>>> print(output)
[ 1. 3.25 5.5 7.75 10. ]
"""
if not isinstance(start, Tensor):
start = Tensor(start, mstype.float32)
if not isinstance(stop, Tensor):
stop = Tensor(stop, mstype.float32)
return linspace_(start, stop, num)

View File

@ -634,9 +634,9 @@ def normal(shape, mean, stddev, seed=None):
Args:
shape (tuple): The shape of random tensor to be generated.
The format is :math:`(N,*)` where :math:`*` means, any number of additional dimensions.
mean (Tensor): The mean μ distribution parameter, which specifies the location of the peak,
mean (Union[Tensor, int, float]): The mean μ distribution parameter, which specifies the location of the peak,
with data type in [int8, int16, int32, int64, float16, float32].
stddev (Tensor): The deviation σ distribution parameter. It should be greater than 0,
stddev (Union[Tensor, int, float]): The deviation σ distribution parameter. It should be greater than 0,
with data type in [int8, int16, int32, int64, float16, float32].
seed (int): Seed is used as entropy source for the Random number engines to generate pseudo-random numbers.
The value must be non-negative. Default: None, which will be treated as 0.
@ -675,12 +675,16 @@ def normal(shape, mean, stddev, seed=None):
>>> print(result)
(3, 3, 3)
"""
if not isinstance(mean, Tensor):
mean = Tensor(mean)
if not isinstance(stddev, Tensor):
mean = Tensor(stddev)
mean_dtype = F.dtype(mean)
stddev_dtype = F.dtype(stddev)
const_utils.check_type_valid(mean_dtype, mstype.int_type + (mstype.float16, mstype.float32), 'normal')
const_utils.check_type_valid(stddev_dtype, mstype.int_type + (mstype.float16, mstype.float32), 'normal')
seed1, seed2 = _get_seed(seed, "normal")
stdnormal = P.StandardNormal(seed1, seed2)
stdnormal = _get_cache_prim(P.StandardNormal)(seed1, seed2)
_check_shape(shape)
random_normal = stdnormal(shape)
value = random_normal * stddev + mean