forked from mindspore-Ecosystem/mindspore

add general -1 dim behavior for BroadcastTo op

commit 6cead43bdf (parent 1965ecb9a1)
@@ -4551,7 +4551,9 @@ class BroadcastTo(PrimitiveWithInfer):
     the target dimension is -1. In case of -1 in target shape, it will be replaced by the input shape's value
     in that dimension.
 
-    When input shape is broadcast to target shape, it starts with the trailing dimensions.
+    When input shape is broadcast to target shape, it starts with the trailing
+    dimensions. If there is a -1 in the target shape, the -1 cannot be in a leading,
+    non-existing dimension.
 
     Args:
         shape (tuple): The target shape to broadcast. Can be fully specified, or have -1 in one position
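The resolution rule the new docstring describes can be sketched in plain Python (a minimal illustration only; resolve_target is a hypothetical helper, not MindSpore source): shapes are right-aligned, each -1 takes the input's size at the aligned position, and a -1 falling ahead of the input's dimensions is rejected.

# Hypothetical sketch of the documented -1 resolution rule: the target may
# add leading dimensions, but a -1 may only sit where the input has one.
def resolve_target(x_shape, target):
    offset = len(target) - len(x_shape)
    resolved = []
    for i, v in enumerate(target):
        if v == -1:
            if i < offset:
                # -1 in a leading, non-existing dimension is invalid
                raise ValueError("-1 is not valid in a leading, non-existing dimension")
            resolved.append(x_shape[i - offset])
        else:
            resolved.append(v)
    return tuple(resolved)

assert resolve_target((4, 5), (2, 3, -1, 5)) == (2, 3, 4, 5)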
@@ -4566,9 +4568,8 @@ class BroadcastTo(PrimitiveWithInfer):
 
     Raises:
         TypeError: If `shape` is not a tuple.
-        ValueError: Given a shape tuple, if it has several -1; or if the -1 is in an invalid position
-            such as one that does not have a opposing dimension in an input tensor; or if the target and
-            input shapes are incompatible.
+        ValueError: If the target and input shapes are incompatible, or if a -1 in the
+            target shape is in an invalid location.
 
     Supported Platforms:
         ``Ascend`` ``GPU``
@@ -4582,13 +4583,13 @@ class BroadcastTo(PrimitiveWithInfer):
         [[1. 2. 3.]
          [1. 2. 3.]]
 
-        >>> shape = (2, -1)
-        >>> input_x = Tensor(np.array([1, 2, 3]).astype(np.float32))
+        >>> shape = (-1, 2)
+        >>> input_x = Tensor(np.array([[1], [2]]).astype(np.float32))
         >>> broadcast_to = ops.BroadcastTo(shape)
         >>> output = broadcast_to(input_x)
        >>> print(output)
-        [[1. 2. 3.]
-         [1. 2. 3.]]
+        [[1. 1.]
+         [2. 2.]]
     """
 
     @prim_attr_register
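The updated doctest resolves its -1 against the input's first dimension: the target (-1, 2) applied to a (2, 1) input becomes (2, 2), and the column vector is repeated across the new axis. NumPy's broadcast_to follows the same trailing-dimension rule once the -1 is resolved (illustrative only):

import numpy as np

# After resolving (-1, 2) against the (2, 1) input, the target is (2, 2);
# NumPy reproduces the doctest's output for the resolved shape.
x = np.array([[1.], [2.]], dtype=np.float32)
print(np.broadcast_to(x, (2, 2)))
# [[1. 1.]
#  [2. 2.]]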
@@ -4600,35 +4601,30 @@ class BroadcastTo(PrimitiveWithInfer):
             validator.check_value_type('target shape index -> ' + str(ix), i, [int], self.name)
             validator.check("shape element", i, "shape element min limit", -1, Rel.GE, self.name)
         self.shape = shape
-        if -1 in self.shape:
-            undef_dims = self.shape.count(-1)
-            if undef_dims > 1:
-                raise ValueError(f'The shape can only has one -1 at most, but has {undef_dims}.')
-            self.dyn = True
-        else:
-            self.dyn = False
 
     def infer_shape(self, x_shape):
         validator.check("input_x shape length", len(x_shape), "target shape", len(self.shape), Rel.LE, self.name)
-        target_shape = list(self.shape)
-        outer_dim_offset = len(target_shape) - len(x_shape)
-        if self.dyn:
-            for i, v in enumerate(target_shape):
-                if v == -1:
-                    if i < outer_dim_offset:
-                        raise ValueError(f" -1 in init shape is in an incompatible location"
-                                         f" with given input tensor, -1 index in init shape: {i}"
-                                         f" but -1 can only be in index {len(x_shape)} onwards for this input.")
-                    target_shape[i] = x_shape[i - outer_dim_offset]
         reversed_x_shape = tuple(reversed(x_shape))
-        reversed_target = tuple(reversed(target_shape))
-        for i, v in enumerate(reversed_x_shape):
-            if v not in (reversed_target[i], 1):
-                raise ValueError(f"Not supported shapes for broadcast, "
-                                 f"x_shape: {tuple(x_shape)}, target shape {target_shape}.")
-        self.shape = tuple(target_shape)
+        reversed_filtered_target = []
+        for i, v in enumerate(tuple(reversed(self.shape))):
+            if v == -1:
+                if i >= len(reversed_x_shape):
+                    raise ValueError("-1 is not valid in a leading, non-existing dimension")
+                reversed_filtered_target.append(reversed_x_shape[i])
+            else:
+                reversed_filtered_target.append(v)
+
+        self.shape = tuple(reversed(reversed_filtered_target))
         self.add_prim_attr('shape', self.shape)
-        return target_shape
+
+        for i, v in enumerate(reversed_x_shape):
+            if v not in (reversed_filtered_target[i], 1):
+                raise ValueError(f"Not supported shapes for broadcast, "
+                                 f"x_shape: {tuple(x_shape)}, target shape {self.shape}.")
+
+        return self.shape
 
     def infer_dtype(self, x_dtype):
         validator.check_subclass("input_x", x_dtype, mstype.tensor, self.name)
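The rewritten infer_shape is two passes over right-aligned (reversed) shapes: first substitute every -1, raising if one lands past the input's dimensions, then verify ordinary broadcast compatibility, where each input dimension must equal the target's or be 1. The second pass can be sketched like so (a plain-Python illustration; check_broadcast is a hypothetical helper):

# Sketch of the trailing-dimension compatibility pass, mirroring the loop
# over reversed shapes in the hunk above.
def check_broadcast(x_shape, target):
    for xd, td in zip(reversed(x_shape), reversed(target)):
        if xd != td and xd != 1:
            raise ValueError(f"Not supported shapes for broadcast, "
                             f"x_shape: {x_shape}, target shape: {target}.")

check_broadcast((1, 5), (2, 3, 4, 5))   # ok: 1 broadcasts up to 4
check_broadcast((4, 5), (2, 3, 4, 5))   # ok: trailing dims match exactly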
@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -54,7 +54,7 @@ def test_broadcast_dyn_init():
     """
     context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
 
-    ms_shape = (-1, 4, 5, 6)
+    ms_shape = (-1, -1, 5, 6)
     np_shape = (3, 4, 5, 6)
     x_np = np.random.rand(3, 1, 5, 1).astype(np.float32)
     output = P.BroadcastTo(ms_shape)(Tensor(x_np))
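Under the resolution rule sketched earlier, this target right-aligns with the (3, 1, 5, 1) input, so both -1 entries take the input's own sizes:

# Reusing the hypothetical resolve_target sketch from above.
assert resolve_target((3, 1, 5, 1), (-1, -1, 5, 6)) == (3, 1, 5, 6)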
@@ -66,7 +66,7 @@ def test_broadcast_dyn_init():
     expect = np.broadcast_to(x1_np, np_shape)
     assert np.allclose(output.asnumpy(), expect)
 
-    ms_shape = (2, 3, -1, 5)
+    ms_shape = (2, 3, -1, -1)
     np_shape = (2, 3, 4, 5)
     x1_np = np.random.rand(4, 5).astype(np.float32)
     output = P.BroadcastTo(ms_shape)(Tensor(x1_np))
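Here the (4, 5) input covers only the last two target positions, so both -1 entries resolve to the input's sizes and the resolved target matches np_shape exactly:

# Again with the hypothetical resolve_target sketch.
assert resolve_target((4, 5), (2, 3, -1, -1)) == (2, 3, 4, 5)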
@@ -87,3 +87,9 @@ def test_broadcast_dyn_invalid_init():
     x_np = np.random.rand(4, 5).astype(np.float32)
     with pytest.raises(ValueError):
         P.BroadcastTo(ms_shape)(Tensor(x_np))
+
+    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
+    ms_shape = (-1, 1, -1, -1)
+    x_np = np.random.rand(4, 5).astype(np.float32)
+    with pytest.raises(ValueError):
+        P.BroadcastTo(ms_shape)(Tensor(x_np))
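The new negative case exercises the leading -1 rule: with a (4, 5) input, only the last two positions of the four-entry target are backed by input dimensions, so the -1 at index 0 must raise:

# With the hypothetical resolve_target sketch, the leading -1 is rejected.
try:
    resolve_target((4, 5), (-1, 1, -1, -1))
except ValueError as e:
    print(e)  # -1 is not valid in a leading, non-existing dimension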