fix bug of sparsegather and broadcastto

This commit is contained in:
simson 2021-04-22 15:05:12 +08:00
parent b5bb831c32
commit 5b72c23972
3 changed files with 28 additions and 3 deletions

View File

@ -48,7 +48,16 @@ abstract::ShapePtr BroadcastToInferShape(const PrimitivePtr &primitive,
}
}
}
return std::make_shared<abstract::Shape>(input_x);
auto x_shape_ptr = std::make_shared<abstract::Shape>(input_x);
primitive->AddAttr("shape", MakeValue(input_x));
for (int64_t i = 0; i < (int64_t)x_shape.size(); i++) {
if (input_x[i + outer_dim_offset] != x_shape[i] && x_shape[i] != 1) {
MS_EXCEPTION(ValueError) << "Not support shapes for broadcast, x_shape: "
<< input_args[0]->BuildShape()->ToString()
<< ", target shape: " << x_shape_ptr->ToString();
}
}
return x_shape_ptr;
}
TypePtr BroadcastToInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {

View File

@ -980,11 +980,11 @@ def get_bprop_batch_to_space_nd(self):
def get_bprop_broadcast_to(self):
"""Generate bprop for BroadcastTo"""
reduce_keep_dim = P.ReduceSum(keep_dims=True)
broadcast_shape = self.shape
def bprop(x, out, dout):
x_shape = shape_op(x)
dout_shape = shape_op(dout)
broadcast_shape = shape_op(out)
if x_shape == dout_shape:
return (dout,)

View File

@ -861,7 +861,7 @@ class GatherV2(PrimitiveWithCheck):
validator.check_int_range(axis_v, -rank, rank, Rel.INC_LEFT, "axis", self.name)
class SparseGatherV2(Gather):
class SparseGatherV2(PrimitiveWithCheck):
"""
Returns a slice of input tensor based on the specified indices and axis.
@ -893,6 +893,22 @@ class SparseGatherV2(Gather):
[2. 55.]]
"""
@prim_attr_register
def __init__(self):
"""Initialize index_select"""
self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output'])
def __check__(self, params, indices, axis):
validator.check_subclass("params", params['dtype'], mstype.tensor, self.name)
validator.check_tensor_dtype_valid("indices", indices['dtype'], mstype.int_type, self.name)
validator.check_subclass("axis", axis['dtype'], [mstype.number], self.name)
axis_v = axis['value']
validator.check_value_type('axis', axis_v, [int], self.name)
rank = len(params['shape'])
validator.check_int_range(axis_v, -rank, rank, Rel.INC_LEFT, "axis", self.name)
class Padding(PrimitiveWithInfer):
"""