This commit is contained in:
lilinjie 2022-10-21 17:29:26 +08:00
parent b84fdb60e4
commit bce8108587
3 changed files with 20 additions and 17 deletions

View File

@ -11,17 +11,17 @@ mindspore.ops.broadcast_to
- 如果不相等,分以下三种情况:
- 情况一如果目标shape该维的值为-1 则输出shape该维的值为对应输入shape该维的值。比如说输入shape为 :math:`(3, 3)` 目标shape为 :math:`(-1, 3)` 则输出shape为 :math:`(3, 3)`
- 情况一如果目标shape该维的值为-1则输出shape该维的值为对应输入shape该维的值。比如说输入shape为 :math:`(3, 3)` 目标shape为 :math:`(-1, 3)` 则输出shape为 :math:`(3, 3)`
- 情况二如果目标shape该维的值不为-1但是输入shape该维的值为1则输出shape该维的值为目标shape该维的值。比如说输入shape为 :math:` (1, 3)` 目标shape为 :math:`(8, 3)` 则输出shape为 :math:`(8, 3)`
- 情况二如果目标shape该维的值不为-1但是输入shape该维的值为1则输出shape该维的值为目标shape该维的值。比如说输入shape为 :math:`(1, 3)` 目标shape为 :math:`(8, 3)` 则输出shape为 :math:`(8, 3)`
- 情况三如果两个shape对应值不满足以上情况则说明不支持由输入shape广播到目标shape。
- 情况三如果两个shape对应值不满足以上情况则说明不支持由输入shape广播到目标shape。
至此输出shape后面m维就确定好了现在看一下前面 :math:`*` 维,有以下两种情况:
- 如果额外的 :math:`*` 维中不含有-1则输入shape从低维度补充维度使之与目标shape维度一致比如说目标shape为 :math:` (3, 1, 4, 1, 5, 9)` 输入shape为 :math:`(1, 5, 9)` 则输入shape增维变成 :math:`(1, 1, 1, 1, 5, 9)`根据上面提到的情况二可以得出输出shape为 :math:` (3, 1, 4, 1, 5, 9)`
- 如果额外的 :math:`*` 维中不含有-1则输入shape从低维度补充维度使之与目标shape维度一致比如说目标shape为 :math:`(3, 1, 4, 1, 5, 9)` 输入shape为 :math:`(1, 5, 9)` 则输入shape增维变成 :math:`(1, 1, 1, 1, 5, 9)`根据上面提到的情况二可以得出输出shape为 :math:`(3, 1, 4, 1, 5, 9)`
- 如果额外的 :math:`*` 维中含有-1说明此时该-1对应一个不存在的维度不支持广播。比如说目标shape为 :math:` (3, -1, 4, 1, 5, 9)` 输入shape为 :math:`(1, 5, 9)` ,此时不进行增维处理,而是直接报错。
- 如果额外的 :math:`*` 维中含有-1说明此时该-1对应一个不存在的维度不支持广播。比如说目标shape为 :math:`(3, -1, 4, 1, 5, 9)` 输入shape为 :math:`(1, 5, 9)` ,此时不进行增维处理,而是直接报错。
参数:
- **x** (Tensor) - 第一个输入任意维度的Tensor数据类型为float16、float32、int32、int8、uint8、bool。

View File

@ -3251,37 +3251,40 @@ def affine_grid(theta, output_size, align_corners=False):
def broadcast_to(x, shape):
"""
Broadcasts input tensor to a given shape. The dim of input shape must be smaller
than or equal to that of target shape, suppose input shape :math:`(x1, x2, ..., xm)`,
target shape :math:`(*, y_1, y_2, ..., y_m)`. The broadcast rules are as follows:
than or equal to that of target shape. Suppose input shape is :math:`(x1, x2, ..., xm)`,
target shape is :math:`(*, y_1, y_2, ..., y_m)`, where :math:`*` means any additional dimension.
The broadcast rules are as follows:
Compare the value of `x_m` and `y_m`, `x_{m-1}` and `y_{m-1}`, ..., `x_1` and `y_1` consecutively and
decide whether these shapes are broadcastable and what the broadcast result is.
If the value pairs at a specific dim are equal, then that value goes right into that dim of output shape.
With an input shape :math:`(2, 3)`, target shape :math:`(2, 3)` , the inferred outpyt shape is :math:`(2, 3)`.
With an input shape :math:`(2, 3)`, target shape :math:`(2, 3)` , the inferred output shape is :math:`(2, 3)`.
If the value pairs are unequal, there are three cases:
Case 1: Value of target shape is -1, then the value of the output shape is that of the input shape's.
With an input shape :math:`(3, 3)`, target shape :math:`(-1, 3)`, the output shape is :math:`(3, 3)`.
Case 1: If the value of the target shape in the dimension is -1, the value of the
output shape in the dimension is the value of the corresponding input shape in the dimension.
Case 2: Value of target shape is not -1 but the value ot the input shape is 1, then the value of the output shape
is that of the target shape's. With an input shape :math:`(1, 3)`, target
Case 2: If the value of target shape in the dimension is not -1, but the corresponding
value in the input shape is 1, then the corresponding value of the output shape
is that of the target shape. With an input shape :math:`(1, 3)`, target
shape :math:`(8, 3)`, the output shape is :math:`(8, 3)`.
Case 3: All other cases mean that the two shapes are not broadcastable.
Case 3: If the corresponding values of the two shapes do not satisfy the above cases,
it means that broadcasting from the input shape to the target shape is not supported.
So far we got the last m dims of the outshape, now focus on the first :math:`*` dims, there are
two cases:
If the first :math:`*` dims of output shape does not have -1 in it, then fill the input
shape with ones until their length are the same, and then refer to
Case 2 mentioned above to calculate the output shape. With target shape :math:` (3, 1, 4, 1, 5, 9)`,
Case 2 mentioned above to calculate the output shape. With target shape :math:`(3, 1, 4, 1, 5, 9)`,
input shape :math:`(1, 5, 9)`, the filled input shape will be :math:`(1, 1, 1, 1, 5, 9)` and thus the
output shape is :math:` (3, 1, 4, 1, 5, 9)`.
output shape is :math:`(3, 1, 4, 1, 5, 9)`.
If the first :math:`*` dims of output shape have -1 in it, it implies this -1 is conrresponding to
a non-existing dim so they're not broadcastable. With target shape :math:` (3, -1, 4, 1, 5, 9)`,
a non-existing dim so they're not broadcastable. With target shape :math:`(3, -1, 4, 1, 5, 9)`,
input shape :math:`(1, 5, 9)`, instead of operating the dim-filling process first, it raises errors directly.
Args:

View File

@ -807,7 +807,7 @@ class UniformCandidateSampler(PrimitiveWithInfer):
Examples:
>>> sampler = ops.UniformCandidateSampler(1, 3, False, 4, 1)
>>> output1, output2, output3 = sampler(Tensor(np.array([[1], [3], [4], [6], [3]], dtype=np.int32)))
>>> output1, output2, output3 = sampler(Tensor(np.array([[1], [3], [4], [6], [3]], dtype=np.int64)))
>>> print(output1.shape)
(3,)
>>> print(output2.shape)