forked from mindspore-Ecosystem/mindspore
fix bug of sparsegather and broadcastto
This commit is contained in:
parent
b5bb831c32
commit
5b72c23972
|
@ -48,7 +48,16 @@ abstract::ShapePtr BroadcastToInferShape(const PrimitivePtr &primitive,
|
|||
}
|
||||
}
|
||||
}
|
||||
return std::make_shared<abstract::Shape>(input_x);
|
||||
auto x_shape_ptr = std::make_shared<abstract::Shape>(input_x);
|
||||
primitive->AddAttr("shape", MakeValue(input_x));
|
||||
for (int64_t i = 0; i < (int64_t)x_shape.size(); i++) {
|
||||
if (input_x[i + outer_dim_offset] != x_shape[i] && x_shape[i] != 1) {
|
||||
MS_EXCEPTION(ValueError) << "Not support shapes for broadcast, x_shape: "
|
||||
<< input_args[0]->BuildShape()->ToString()
|
||||
<< ", target shape: " << x_shape_ptr->ToString();
|
||||
}
|
||||
}
|
||||
return x_shape_ptr;
|
||||
}
|
||||
|
||||
TypePtr BroadcastToInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
|
|
|
@ -980,11 +980,11 @@ def get_bprop_batch_to_space_nd(self):
|
|||
def get_bprop_broadcast_to(self):
|
||||
"""Generate bprop for BroadcastTo"""
|
||||
reduce_keep_dim = P.ReduceSum(keep_dims=True)
|
||||
broadcast_shape = self.shape
|
||||
|
||||
def bprop(x, out, dout):
|
||||
x_shape = shape_op(x)
|
||||
dout_shape = shape_op(dout)
|
||||
broadcast_shape = shape_op(out)
|
||||
|
||||
if x_shape == dout_shape:
|
||||
return (dout,)
|
||||
|
|
|
@ -861,7 +861,7 @@ class GatherV2(PrimitiveWithCheck):
|
|||
validator.check_int_range(axis_v, -rank, rank, Rel.INC_LEFT, "axis", self.name)
|
||||
|
||||
|
||||
class SparseGatherV2(Gather):
|
||||
class SparseGatherV2(PrimitiveWithCheck):
|
||||
"""
|
||||
Returns a slice of input tensor based on the specified indices and axis.
|
||||
|
||||
|
@ -893,6 +893,22 @@ class SparseGatherV2(Gather):
|
|||
[2. 55.]]
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""Initialize index_select"""
|
||||
self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output'])
|
||||
|
||||
|
||||
def __check__(self, params, indices, axis):
|
||||
validator.check_subclass("params", params['dtype'], mstype.tensor, self.name)
|
||||
validator.check_tensor_dtype_valid("indices", indices['dtype'], mstype.int_type, self.name)
|
||||
validator.check_subclass("axis", axis['dtype'], [mstype.number], self.name)
|
||||
axis_v = axis['value']
|
||||
validator.check_value_type('axis', axis_v, [int], self.name)
|
||||
rank = len(params['shape'])
|
||||
validator.check_int_range(axis_v, -rank, rank, Rel.INC_LEFT, "axis", self.name)
|
||||
|
||||
|
||||
|
||||
class Padding(PrimitiveWithInfer):
|
||||
"""
|
||||
|
|
Loading…
Reference in New Issue