diff --git a/mindspore/core/ops/broadcast_to.cc b/mindspore/core/ops/broadcast_to.cc index 03d71a7ac3a..e0ac4692d6f 100644 --- a/mindspore/core/ops/broadcast_to.cc +++ b/mindspore/core/ops/broadcast_to.cc @@ -48,7 +48,16 @@ abstract::ShapePtr BroadcastToInferShape(const PrimitivePtr &primitive, } } } - return std::make_shared(input_x); + auto x_shape_ptr = std::make_shared(input_x); + primitive->AddAttr("shape", MakeValue(input_x)); + for (int64_t i = 0; i < (int64_t)x_shape.size(); i++) { + if (input_x[i + outer_dim_offset] != x_shape[i] && x_shape[i] != 1) { + MS_EXCEPTION(ValueError) << "Not support shapes for broadcast, x_shape: " + << input_args[0]->BuildShape()->ToString() + << ", target shape: " << x_shape_ptr->ToString(); + } + } + return x_shape_ptr; } TypePtr BroadcastToInferType(const PrimitivePtr &prim, const std::vector &input_args) { diff --git a/mindspore/ops/_grad/grad_array_ops.py b/mindspore/ops/_grad/grad_array_ops.py index f0787d25b51..2fe6b50fdb0 100644 --- a/mindspore/ops/_grad/grad_array_ops.py +++ b/mindspore/ops/_grad/grad_array_ops.py @@ -980,11 +980,11 @@ def get_bprop_batch_to_space_nd(self): def get_bprop_broadcast_to(self): """Generate bprop for BroadcastTo""" reduce_keep_dim = P.ReduceSum(keep_dims=True) - broadcast_shape = self.shape def bprop(x, out, dout): x_shape = shape_op(x) dout_shape = shape_op(dout) + broadcast_shape = shape_op(out) if x_shape == dout_shape: return (dout,) diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index 735c09079cc..7189ac1f49d 100644 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -861,7 +861,7 @@ class GatherV2(PrimitiveWithCheck): validator.check_int_range(axis_v, -rank, rank, Rel.INC_LEFT, "axis", self.name) -class SparseGatherV2(Gather): +class SparseGatherV2(PrimitiveWithCheck): """ Returns a slice of input tensor based on the specified indices and axis. @@ -893,6 +893,22 @@ class SparseGatherV2(Gather): [2. 55.]] """ + @prim_attr_register + def __init__(self): + """Initialize index_select""" + self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output']) + + + def __check__(self, params, indices, axis): + validator.check_subclass("params", params['dtype'], mstype.tensor, self.name) + validator.check_tensor_dtype_valid("indices", indices['dtype'], mstype.int_type, self.name) + validator.check_subclass("axis", axis['dtype'], [mstype.number], self.name) + axis_v = axis['value'] + validator.check_value_type('axis', axis_v, [int], self.name) + rank = len(params['shape']) + validator.check_int_range(axis_v, -rank, rank, Rel.INC_LEFT, "axis", self.name) + + class Padding(PrimitiveWithInfer): """