forked from mindspore-Ecosystem/mindspore
solve broadcast two same shape bprop error
make unsupported shape error info explicit
This commit is contained in:
parent
9bc2ffde54
commit
5962c6efe9
|
@ -673,6 +673,10 @@ def get_bprop_broadcast_to(self):
|
||||||
|
|
||||||
def bprop(x, out, dout):
|
def bprop(x, out, dout):
|
||||||
x_shape = shape_op(x)
|
x_shape = shape_op(x)
|
||||||
|
dout_shape = shape_op(dout)
|
||||||
|
|
||||||
|
if x_shape == dout_shape:
|
||||||
|
return (dout,)
|
||||||
_, reduction_axes = broadcast_gradient_args(broadcast_shape, x_shape)
|
_, reduction_axes = broadcast_gradient_args(broadcast_shape, x_shape)
|
||||||
reduced_grad = reduce_keep_dim(dout, reduction_axes)
|
reduced_grad = reduce_keep_dim(dout, reduction_axes)
|
||||||
dx = reshape(reduced_grad, x_shape)
|
dx = reshape(reduced_grad, x_shape)
|
||||||
|
|
|
@ -2719,6 +2719,8 @@ class BatchToSpaceND(PrimitiveWithInfer):
|
||||||
class BroadcastTo(PrimitiveWithInfer):
|
class BroadcastTo(PrimitiveWithInfer):
|
||||||
"""
|
"""
|
||||||
Broadcasts input tensor to a given shape.
|
Broadcasts input tensor to a given shape.
|
||||||
|
Input shape can be broadcast to target shape if for each dimension pair they are either equal or input is one.
|
||||||
|
When input shape is broadcast to target shape, it starts with the trailing dimensions.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
shape (tuple): The target shape to broadcast.
|
shape (tuple): The target shape to broadcast.
|
||||||
|
@ -2741,11 +2743,20 @@ class BroadcastTo(PrimitiveWithInfer):
|
||||||
def __init__(self, shape):
|
def __init__(self, shape):
|
||||||
"""Init BroadcastTo"""
|
"""Init BroadcastTo"""
|
||||||
validator.check_value_type("shape", shape, (tuple), self.name)
|
validator.check_value_type("shape", shape, (tuple), self.name)
|
||||||
|
validator.check("shape length", len(shape), "", 0, Rel.GT, self.name)
|
||||||
for i in shape:
|
for i in shape:
|
||||||
validator.check_integer("shape element", i, 0, Rel.GT, self.name)
|
validator.check_integer("shape element", i, 0, Rel.GT, self.name)
|
||||||
self.shape = shape
|
self.shape = shape
|
||||||
|
|
||||||
def infer_shape(self, x_shape):
|
def infer_shape(self, x_shape):
|
||||||
|
validator.check("input_x shape length", len(x_shape), "target shape", len(self.shape), Rel.LE, self.name)
|
||||||
|
|
||||||
|
reversed_x_shape = tuple(reversed(x_shape))
|
||||||
|
reversed_target = tuple(reversed(self.shape))
|
||||||
|
for i, v in enumerate(reversed_x_shape):
|
||||||
|
if v not in (reversed_target[i], 1):
|
||||||
|
raise ValueError(f"Not supported shapes for broadcast, "
|
||||||
|
f"x_shape: {tuple(x_shape)}, target shape {self.shape}.")
|
||||||
return self.shape
|
return self.shape
|
||||||
|
|
||||||
def infer_dtype(self, x_dtype):
|
def infer_dtype(self, x_dtype):
|
||||||
|
|
Loading…
Reference in New Issue