Fix BroadcastTo bprop error when the input and broadcast shapes are the same

Make the error message for unsupported broadcast shapes explicit
Author: zhaozhenlong 2020-06-19 15:13:58 +08:00
parent 9bc2ffde54
commit 5962c6efe9
2 changed files with 15 additions and 0 deletions


@@ -673,6 +673,10 @@ def get_bprop_broadcast_to(self):
    def bprop(x, out, dout):
        x_shape = shape_op(x)
        dout_shape = shape_op(dout)
        if x_shape == dout_shape:
            return (dout,)
        _, reduction_axes = broadcast_gradient_args(broadcast_shape, x_shape)
        reduced_grad = reduce_keep_dim(dout, reduction_axes)
        dx = reshape(reduced_grad, x_shape)
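For reference, the reduction this bprop performs can be sketched in plain NumPy. This is a minimal illustration, not MindSpore code: broadcast_to_grad is a hypothetical name, and numpy's sum/reshape stand in for broadcast_gradient_args, ReduceSum(keep_dims=True) and Reshape.

import numpy as np

def broadcast_to_grad(dout, x_shape):
    # dout has the broadcast (target) shape; the result has x_shape.
    if dout.shape == tuple(x_shape):
        # Same-shape shortcut added by this commit: nothing was broadcast,
        # so the gradient flows back unchanged.
        return dout
    # Align shapes from the trailing dimensions, as broadcasting does.
    padded = (1,) * (dout.ndim - len(x_shape)) + tuple(x_shape)
    # Dimensions where the input was 1 (or absent) were broadcast, so the
    # incoming gradient is summed over them, keeping dims for the reshape.
    axes = tuple(i for i, (xd, dd) in enumerate(zip(padded, dout.shape))
                 if xd == 1 and dd != 1)
    reduced = dout.sum(axis=axes, keepdims=True) if axes else dout
    return reduced.reshape(x_shape)

dout = np.ones((2, 3))
print(broadcast_to_grad(dout, (2, 3)).shape)  # (2, 3): same-shape shortcut
print(broadcast_to_grad(dout, (1, 3)).shape)  # (1, 3): summed over axis 0

The early return mirrors the fix above: when the input already has the target shape, no axes were broadcast, so there are no reduction axes to compute and the gradient passes through unchanged.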


@@ -2719,6 +2719,8 @@ class BatchToSpaceND(PrimitiveWithInfer):
class BroadcastTo(PrimitiveWithInfer):
    """
    Broadcasts input tensor to a given shape.

    The input shape can be broadcast to the target shape if, for each dimension pair,
    the dimensions are either equal or the input dimension is 1. Broadcasting aligns
    the shapes starting from the trailing dimensions.

    Args:
        shape (tuple): The target shape to broadcast.
@@ -2741,11 +2743,20 @@ class BroadcastTo(PrimitiveWithInfer):
    def __init__(self, shape):
        """Init BroadcastTo"""
        validator.check_value_type("shape", shape, (tuple), self.name)
        validator.check("shape length", len(shape), "", 0, Rel.GT, self.name)
        for i in shape:
            validator.check_integer("shape element", i, 0, Rel.GT, self.name)
        self.shape = shape

    def infer_shape(self, x_shape):
        validator.check("input_x shape length", len(x_shape), "target shape", len(self.shape), Rel.LE, self.name)
        reversed_x_shape = tuple(reversed(x_shape))
        reversed_target = tuple(reversed(self.shape))
        for i, v in enumerate(reversed_x_shape):
            if v not in (reversed_target[i], 1):
                raise ValueError(f"Not supported shapes for broadcast, "
                                 f"x_shape: {tuple(x_shape)}, target shape {self.shape}.")
        return self.shape

    def infer_dtype(self, x_dtype):
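The new shape check can be exercised on its own; below is a minimal sketch in plain Python of the rule infer_shape enforces. check_broadcast_shape is a hypothetical helper used only for illustration and is not part of the MindSpore API.

def check_broadcast_shape(x_shape, target):
    # Mirrors BroadcastTo.infer_shape: the input rank must not exceed the
    # target rank, and dimensions are compared from the trailing end.
    if len(x_shape) > len(target):
        raise ValueError(f"x_shape rank {len(x_shape)} exceeds target rank {len(target)}.")
    for xd, td in zip(reversed(x_shape), reversed(target)):
        # Each input dimension must equal the target dimension or be 1.
        if xd not in (td, 1):
            raise ValueError(f"Not supported shapes for broadcast, "
                             f"x_shape: {tuple(x_shape)}, target shape {tuple(target)}.")
    return tuple(target)

print(check_broadcast_shape((1, 3), (2, 3)))  # (2, 3): a dimension of 1 broadcasts
print(check_broadcast_shape((2, 3), (2, 3)))  # (2, 3): shapes already equal
# check_broadcast_shape((4, 3), (2, 3)) raises the explicit ValueError above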