add general -1 dim behavior for BroadcastTo op

This commit is contained in:
Peilin Wang 2021-03-20 01:42:40 -04:00
parent 1965ecb9a1
commit 6cead43bdf
2 changed files with 37 additions and 35 deletions

View File

@ -4551,7 +4551,9 @@ class BroadcastTo(PrimitiveWithInfer):
the target dimension is -1. In case of -1 in target shape, it will be replaced by the input shape's value
in that dimension.
When input shape is broadcast to target shape, it starts with the trailing dimensions.
When input shape is broadcast to target shape, it starts with the trailing
dimensions. If there is a -1 in the target shape, the -1 cannot be in a leading,
non-existing dimension.
Args:
shape (tuple): The target shape to broadcast. Can be fully specified, or have -1 in one position
@ -4566,9 +4568,8 @@ class BroadcastTo(PrimitiveWithInfer):
Raises:
TypeError: If `shape` is not a tuple.
ValueError: Given a shape tuple, if it has several -1; or if the -1 is in an invalid position
such as one that does not have a opposing dimension in an input tensor; or if the target and
input shapes are incompatible.
ValueError: if the target and input shapes are incompatible, or if a -1 in the
target shape is in an invalid location.
Supported Platforms:
``Ascend`` ``GPU``
@ -4582,13 +4583,13 @@ class BroadcastTo(PrimitiveWithInfer):
[[1. 2. 3.]
[1. 2. 3.]]
>>> shape = (2, -1)
>>> input_x = Tensor(np.array([1, 2, 3]).astype(np.float32))
>>> shape = (-1, 2)
>>> input_x = Tensor(np.array([[1], [2]]).astype(np.float32))
>>> broadcast_to = ops.BroadcastTo(shape)
>>> output = broadcast_to(input_x)
>>> print(output)
[[1. 2. 3.]
[1. 2. 3.]]
[[1. 1.]
[2. 2.]]
"""
@prim_attr_register
@ -4600,35 +4601,30 @@ class BroadcastTo(PrimitiveWithInfer):
validator.check_value_type('target shape index -> ' + str(ix), i, [int], self.name)
validator.check("shape element", i, "shape element min limit", -1, Rel.GE, self.name)
self.shape = shape
if -1 in self.shape:
undef_dims = self.shape.count(-1)
if undef_dims > 1:
raise ValueError(f'The shape can only has one -1 at most, but has {undef_dims}.')
self.dyn = True
else:
self.dyn = False
def infer_shape(self, x_shape):
validator.check("input_x shape length", len(x_shape), "target shape", len(self.shape), Rel.LE, self.name)
target_shape = list(self.shape)
outer_dim_offset = len(target_shape) - len(x_shape)
if self.dyn:
for i, v in enumerate(target_shape):
if v == -1:
if i < outer_dim_offset:
raise ValueError(f" -1 in init shape is in an incompatible location"
f" with given input tensor, -1 index in init shape: {i}"
f" but -1 can only be in index {len(x_shape)} onwards for this input.")
target_shape[i] = x_shape[i - outer_dim_offset]
reversed_x_shape = tuple(reversed(x_shape))
reversed_target = tuple(reversed(target_shape))
for i, v in enumerate(reversed_x_shape):
if v not in (reversed_target[i], 1):
raise ValueError(f"Not supported shapes for broadcast, "
f"x_shape: {tuple(x_shape)}, target shape {target_shape}.")
self.shape = tuple(target_shape)
reversed_filtered_target = []
for i, v in enumerate(tuple(reversed(self.shape))):
if v == -1:
if i >= len(reversed_x_shape):
raise ValueError("-1 is not valid in a leading, non-existing dimension")
reversed_filtered_target.append(reversed_x_shape[i])
else:
reversed_filtered_target.append(v)
self.shape = tuple(reversed(reversed_filtered_target))
self.add_prim_attr('shape', self.shape)
return target_shape
for i, v in enumerate(reversed_x_shape):
if v not in (reversed_filtered_target[i], 1):
raise ValueError(f"Not supported shapes for broadcast, "
f"x_shape: {tuple(x_shape)}, target shape {self.shape}.")
return self.shape
def infer_dtype(self, x_dtype):
validator.check_subclass("input_x", x_dtype, mstype.tensor, self.name)

View File

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -54,7 +54,7 @@ def test_broadcast_dyn_init():
"""
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
ms_shape = (-1, 4, 5, 6)
ms_shape = (-1, -1, 5, 6)
np_shape = (3, 4, 5, 6)
x_np = np.random.rand(3, 1, 5, 1).astype(np.float32)
output = P.BroadcastTo(ms_shape)(Tensor(x_np))
@ -66,7 +66,7 @@ def test_broadcast_dyn_init():
expect = np.broadcast_to(x1_np, np_shape)
assert np.allclose(output.asnumpy(), expect)
ms_shape = (2, 3, -1, 5)
ms_shape = (2, 3, -1, -1)
np_shape = (2, 3, 4, 5)
x1_np = np.random.rand(4, 5).astype(np.float32)
output = P.BroadcastTo(ms_shape)(Tensor(x1_np))
@ -87,3 +87,9 @@ def test_broadcast_dyn_invalid_init():
x_np = np.random.rand(4, 5).astype(np.float32)
with pytest.raises(ValueError):
P.BroadcastTo(ms_shape)(Tensor(x_np))
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
ms_shape = (-1, 1, -1, -1)
x_np = np.random.rand(4, 5).astype(np.float32)
with pytest.raises(ValueError):
P.BroadcastTo(ms_shape)(Tensor(x_np))