Forked from mindspore-Ecosystem/mindspore

Fix the shape infer of TruncatedNormal, a bug in structure output, and a bug in TensorSlice ellipsis handling

This commit is contained in:
parent ffdad1acd4
commit c8a1a24ce3
@ -1084,6 +1084,7 @@ int GenerateStridedSliceParametersFromTuple(const AbstractTuplePtr &slice_tuple,
   std::vector<unsigned int> shrink;
   auto slice_tuple_eles = slice_tuple->elements();
+  size_t ellipsis_num = 0;

   for (size_t index = 0; index < slice_tuple_size; index++) {
     if (slice_tuple_eles[index]->isa<AbstractSlice>()) {
       AbstractSlicePtr slice = dyn_cast<AbstractSlice>(slice_tuple_eles[index]);
@ -1118,12 +1119,13 @@ int GenerateStridedSliceParametersFromTuple(const AbstractTuplePtr &slice_tuple,
                  << slice_tuple_eles[index]->ToString();
   }

-  for (size_t index = slice_tuple_size; index < shape_size; index++) {
-    begin->push_back(0);
-    end->push_back(shape[index]);
-    strides->push_back(1);
+  if (ellipsis_num == 0) {
+    for (size_t index = slice_tuple_size; index < shape_size; index++) {
+      begin->push_back(0);
+      end->push_back(shape[index]);
+      strides->push_back(1);
+    }
   }

   return ConvertBinaryToDecimal(shrink);
 }
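For intuition, here is a minimal Python sketch of the trailing-dimension fill this hunk guards. The function name and signature are illustrative, not the real C++ API: once an ellipsis has been seen, the remaining dimensions were already expanded, so padding them again would emit spurious slice parameters.

# Illustrative model only, assuming the same begin/end/strides convention.
def fill_trailing_dims(shape, slice_tuple_size, ellipsis_num, begin, end, strides):
    # Pad unindexed trailing dimensions only when no ellipsis was present;
    # an ellipsis already expanded to cover them (the bug being fixed).
    if ellipsis_num == 0:
        for index in range(slice_tuple_size, len(shape)):
            begin.append(0)
            end.append(shape[index])
            strides.append(1)
    return begin, end, strides

# Shape (6, 7, 8, 9) indexed by t[0:4:2]: one explicit slice, no ellipsis.
b, e, s = fill_trailing_dims([6, 7, 8, 9], 1, 0, [0], [4], [2])
assert (b, e, s) == ([0, 0, 0, 0], [4, 7, 8, 9], [2, 1, 1, 1])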
@ -1199,6 +1201,7 @@ FuncGraphPtr TensorSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec
       if (scalar_ptr->BuildValue()->cast<BoolImmPtr>()->value()) {
         return ExpandADim(ret_graph, tensor_node);
       }
+      MS_LOG(EXCEPTION) << "TensorSlice not support the index is False.";
     }
     shrink_axis_mask = GenerateStridedSliceParametersFromNumber(scalar_ptr, shape, &begin, &end, &strides);
   } else if (args_spec_list[1]->isa<AbstractEllipsis>()) {
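The new exception pins down the boolean-index contract: tensor[True] expands a leading dimension (like tensor[None]), while tensor[False] has no representable static shape and now fails loudly instead of building a wrong graph. NumPy shows the same expansion for a true scalar index; a sketch for comparison only:

import numpy as np

a = np.ones((3, 4))
print(a[True].shape)   # (1, 3, 4): a True scalar index inserts a new axis
print(a[None].shape)   # (1, 3, 4): None (newaxis) does the same
# NumPy gives a[False] the shape (0, 3, 4); a zero-length axis is not
# expressible as a static shape here, hence the exception above.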
@ -319,19 +319,24 @@ void RunGEInitGraph(const py::dict &init_params, const std::string &phase) {

 py::object ExtractGeneralCnodeRet(const AbstractBasePtr &cnode_data, const py::tuple &data, size_t *count) {
   MS_EXCEPTION_IF_NULL(cnode_data);
-  if (*count >= data.size()) {
-    MS_LOG(EXCEPTION) << "The number of elements in the outputs : " << data.size()
-                      << " less than the number of elements required. ";
-  }
-
   if (cnode_data->isa<AbstractTensor>()) {
-    BaseShapePtr shape = cnode_data->BuildShape();
-    auto shape_act = shape->cast<abstract::ShapePtr>()->shape();
-    Tensor tensor_exp = py::cast<Tensor>(data[*count]);
-    if (shape_act != tensor_exp.shape()) {
-      MS_LOG(EXCEPTION) << "The shape of the tensor returned from GE is not the same as "
-                           "the shape of the tensor derived from ME.";
+    if (*count >= data.size()) {
+      MS_LOG(EXCEPTION) << "The number of elements in the outputs : " << data.size()
+                        << " less than the number of elements required. ";
     }

+    BaseShapePtr shape = cnode_data->BuildShape();
+    if (!shape->isa<abstract::Shape>()) {
+      MS_LOG(EXCEPTION) << "The shape of the tensor derived is not Shape, is " << shape->ToString();
+    }
+    auto shape_me = shape->cast<abstract::ShapePtr>()->shape();
+    auto shape_ge = py::cast<Tensor>(data[*count]).shape();
+    if (shape_ge != shape_me) {
+      MS_LOG(EXCEPTION) << "The shape of the " << *count << "th tensor returned: " << shape_ge
+                        << " is not the same as the shape of the tensor derived: " << shape_me;
+    }
     return data[(*count)++];
   }
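The restructure defers the bounds check to the tensor branch where data[*count] is actually consumed, and the mismatch message now names the output index and both shapes. A hedged Python sketch of the same validation pattern; the helper and its names are invented for illustration:

def check_output_tensor(data, count, shape_me):
    """Validate the count-th runtime output against the inferred shape."""
    if count >= len(data):
        raise IndexError(f"The number of elements in the outputs: {len(data)} "
                         f"is less than the number of elements required.")
    shape_ge = list(data[count].shape)
    if shape_ge != list(shape_me):
        raise ValueError(f"The shape of the {count}th tensor returned: {shape_ge} "
                         f"is not the same as the shape derived: {list(shape_me)}")
    return data[count], count + 1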
@ -357,11 +362,11 @@ py::object StructureOutput(const AnfNodePtr &output_node, const py::tuple &data,
     return ValuePtrToPyData(GetValueNode(output_node));
   }

-  if (*count >= data.size()) {
-    MS_LOG(EXCEPTION) << "The number of elements in the outputs : " << data.size()
-                      << " less than the number of elements required. ";
-  }
   if (output_node->isa<Parameter>()) {
+    if (*count >= data.size()) {
+      MS_LOG(EXCEPTION) << "The number of elements in the outputs : " << data.size()
+                        << " less than the number of elements required. ";
+    }
     return data[(*count)++];
   }
@ -147,6 +147,21 @@ def _tensor_getitem_by_number(data, number_index):
     return _tensor_slice(data, number_index)


+@getitem.register("Tensor", "None")
+def _tensor_getitem_by_none(data, index):
+    """
+    Getting item of tensor by None.
+
+    Inputs:
+        data (Tensor): A tensor.
+        index (None): None.
+
+    Outputs:
+        Tensor, element type is the same as the element type of data.
+    """
+    return _tensor_slice(data, index)
+
+
 @getitem.register("Tensor", "Slice")
 def _tensor_getitem_by_slice(data, slice_index):
     """
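A usage sketch of the newly registered None index, assuming the 0.x Tensor indexing API; the expected shape follows from None inserting a leading axis:

import numpy as np
from mindspore import Tensor

x = Tensor(np.ones([3, 4], np.float32))
y = x[None]   # dispatches through _tensor_getitem_by_none via _tensor_slice
# y is expected to have shape (1, 3, 4)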
@ -633,7 +633,7 @@ class TruncatedNormal(PrimitiveWithInfer):
         dtype (:class:`mindspore.dtype`): Data type. Default: mindspore.float32.

     Inputs:
-        - **shape** (Tensor) - Shape of output tensor. The shape is a 1-D tensor, and type is int.
+        - **shape** (tuple[int]) - Shape of the output tensor; a tuple of positive ints.

     Outputs:
         Tensor, type of the output tensor is the same as attribute `dtype`.
@ -651,16 +651,10 @@ class TruncatedNormal(PrimitiveWithInfer):
         validator.check_typename('dtype', dtype, mstype.number_type)

     def __infer__(self, shape):
-        shape_t = shape['value']
-        validator.check_subclass("shape", shape['dtype'], mstype.tensor)
-        shape_n = shape_t.asnumpy()
-        if shape_n.ndim != 1:
-            raise ValueError('The rank of input shape must be 1.')
-        if shape_n.dtype not in (np.int32, np.int64):
-            raise TypeError('The type of input shape must be int32 or int64.')
-        for i, item in enumerate(shape_n):
-            validator.check_integer(f"shape[{i}]", item.item(), 0, Rel.GT)
-        out = {'shape': tuple(shape_n),
+        shape_value = shape['value']
+        for i, value in enumerate(shape_value):
+            validator.check_integer(f'{i}th value of shape', value, 0, Rel.GT)
+        out = {'shape': shape_value,
                'dtype': mstype.tensor_type(self.dtype),
                'value': None}
         return out
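With the shape now a compile-time tuple of positive ints instead of a 1-D int tensor, __infer__ can validate each entry directly and no longer needs asnumpy(). A usage sketch under that assumption:

import mindspore.common.dtype as mstype
from mindspore.ops import operations as P

truncated_normal = P.TruncatedNormal(dtype=mstype.float32)
out = truncated_normal((2, 3))   # shape is a tuple constant, not a Tensor
# out: float32 tensor of shape (2, 3); a non-positive entry in the tuple
# now fails validator.check_integer during infer.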
@ -1648,20 +1642,19 @@ class StridedSlice(PrimitiveWithInfer):
         validator.check_type('shrink_axis_mask', shrink_axis_mask, [int])

     def __infer__(self, x, begin, end, strides):
-        begin_shape, end_shape, strides_shape = begin['shape'], end['shape'], strides['shape']
-        if begin_shape != strides_shape or end_shape != strides_shape:
-            raise ValueError("The shape of begin, end and strides in 'StridedSlice' must be equal.")
-
-        validator.check_const_input("begin", begin['value'])
-        validator.check_const_input("end", end['value'])
-        validator.check_const_input("strides", strides['value'])
-        validator.check_type("begin", begin['value'], [tuple])
-        validator.check_type("end", end['value'], [tuple])
-        validator.check_type("strides", strides['value'], [tuple])
-
         x_shape = x['shape']
         x_shp_len = len(x_shape)
         begin_v, end_v, strides_v = begin['value'], end['value'], strides['value']
+        validator.check_const_input("begin", begin_v)
+        validator.check_const_input("end", end_v)
+        validator.check_const_input("strides", strides_v)
+        validator.check_type("begin", begin['value'], [tuple])
+        validator.check_type("end", end['value'], [tuple])
+        validator.check_type("strides", strides['value'], [tuple])
+        if len(begin_v) != x_shp_len or len(end_v) != x_shp_len or len(strides_v) != x_shp_len:
+            raise ValueError(f"The length of begin index{begin_v}, end index{end_v} and strides{strides_v} "
+                             f"must be equal to the dims({x_shp_len}) of input.")

         ret_shape = []
         append_dimensions = []
         shrink_pos = bin(self.shrink_axis_mask)[::-1]
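The added length check rejects begin/end/strides tuples whose length differs from the input rank before any shape arithmetic runs. A usage sketch consistent with the validated contract:

import numpy as np
from mindspore import Tensor
from mindspore.ops import operations as P

x = Tensor(np.arange(24).reshape(2, 3, 4).astype(np.float32))
slice_op = P.StridedSlice()
out = slice_op(x, (0, 0, 0), (2, 3, 4), (1, 1, 2))  # rank-3 input, length-3 tuples
# expected output shape: (2, 3, 2)
# slice_op(x, (0, 0), (2, 3), (1, 1)) would now raise the new ValueError,
# since the tuple lengths do not match dims(3).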
@ -372,7 +372,7 @@ test_case_math_ops = [
         'desc_bprop': [[3]]}),
     ('TruncatedNormal', {
         'block': P.TruncatedNormal(),
-        'desc_const': [Tensor(np.array([1, 2, 3]))],
+        'desc_const': [[1, 2, 3]],
         'desc_inputs': [],
         'skip': ['backward'],
         'add_fake_input': True}),
@ -52,8 +52,9 @@ class NetWorkSliceEllipsis(Cell):
     def construct(self, tensor):
         ret0 = tensor[0:4:2, ..., 1] + self.tensor_ret0
         ret1 = tensor[...] + self.tensor_ret1
-        ret2 = tensor[True] + self.tensor_ret2
-        return ret0, ret1, ret2
+        ret2 = tensor[None] + self.tensor_ret2
+        ret3 = tensor[True] + self.tensor_ret2
+        return ret0, ret1, ret2, ret3


 class NetWorkReduceDimension(Cell):
@ -305,7 +306,7 @@ test_cases = [
         'block': NetWorkReduceToScalar(),
         'desc_inputs': [Tensor(np.ones([6, 8, 10], np.int32))],
     }),
-    ('NetWorkSliceEllipsis', {
+    ('TensorSliceEllipsis', {
         'block': NetWorkSliceEllipsis(),
         'desc_inputs': [Tensor(np.ones([6, 7, 8, 9], np.int32))],
     }),
|