fix the infer of TruncatedNormal and a bug of structure output and a bug of tensorslice ellipsis

This commit is contained in:
buxue 2020-04-23 21:21:15 +08:00
parent ffdad1acd4
commit c8a1a24ce3
6 changed files with 62 additions and 45 deletions

View File

@ -1084,6 +1084,7 @@ int GenerateStridedSliceParametersFromTuple(const AbstractTuplePtr &slice_tuple,
std::vector<unsigned int> shrink;
auto slice_tuple_eles = slice_tuple->elements();
size_t ellipsis_num = 0;
for (size_t index = 0; index < slice_tuple_size; index++) {
if (slice_tuple_eles[index]->isa<AbstractSlice>()) {
AbstractSlicePtr slice = dyn_cast<AbstractSlice>(slice_tuple_eles[index]);
@ -1118,12 +1119,13 @@ int GenerateStridedSliceParametersFromTuple(const AbstractTuplePtr &slice_tuple,
<< slice_tuple_eles[index]->ToString();
}
for (size_t index = slice_tuple_size; index < shape_size; index++) {
begin->push_back(0);
end->push_back(shape[index]);
strides->push_back(1);
if (ellipsis_num == 0) {
for (size_t index = slice_tuple_size; index < shape_size; index++) {
begin->push_back(0);
end->push_back(shape[index]);
strides->push_back(1);
}
}
return ConvertBinaryToDecimal(shrink);
}
@ -1199,6 +1201,7 @@ FuncGraphPtr TensorSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec
if (scalar_ptr->BuildValue()->cast<BoolImmPtr>()->value()) {
return ExpandADim(ret_graph, tensor_node);
}
MS_LOG(EXCEPTION) << "TensorSlice not support the index is False.";
}
shrink_axis_mask = GenerateStridedSliceParametersFromNumber(scalar_ptr, shape, &begin, &end, &strides);
} else if (args_spec_list[1]->isa<AbstractEllipsis>()) {

View File

@ -319,19 +319,24 @@ void RunGEInitGraph(const py::dict &init_params, const std::string &phase) {
py::object ExtractGeneralCnodeRet(const AbstractBasePtr &cnode_data, const py::tuple &data, size_t *count) {
MS_EXCEPTION_IF_NULL(cnode_data);
if (*count >= data.size()) {
MS_LOG(EXCEPTION) << "The number of elements in the outputs : " << data.size()
<< " less than the number of elements required. ";
}
if (cnode_data->isa<AbstractTensor>()) {
BaseShapePtr shape = cnode_data->BuildShape();
auto shape_act = shape->cast<abstract::ShapePtr>()->shape();
Tensor tensor_exp = py::cast<Tensor>(data[*count]);
if (shape_act != tensor_exp.shape()) {
MS_LOG(EXCEPTION) << "The shape of the tensor returned from GE is not the same as "
"the shape of the tensor derived from ME.";
if (*count >= data.size()) {
MS_LOG(EXCEPTION) << "The number of elements in the outputs : " << data.size()
<< " less than the number of elements required. ";
}
BaseShapePtr shape = cnode_data->BuildShape();
if (!shape->isa<abstract::Shape>()) {
MS_LOG(EXCEPTION) << "The shape of the tensor derived is not Shape, is " << shape->ToString();
}
auto shape_me = shape->cast<abstract::ShapePtr>()->shape();
auto shape_ge = py::cast<Tensor>(data[*count]).shape();
if (shape_ge != shape_me) {
MS_LOG(EXCEPTION) << "The shape of the " << *count << "th tensor returned: " << shape_ge
<< " is not the same as the shape of the tensor derived: " << shape_me;
}
return data[(*count)++];
}
@ -357,11 +362,11 @@ py::object StructureOutput(const AnfNodePtr &output_node, const py::tuple &data,
return ValuePtrToPyData(GetValueNode(output_node));
}
if (*count >= data.size()) {
MS_LOG(EXCEPTION) << "The number of elements in the outputs : " << data.size()
<< " less than the number of elements required. ";
}
if (output_node->isa<Parameter>()) {
if (*count >= data.size()) {
MS_LOG(EXCEPTION) << "The number of elements in the outputs : " << data.size()
<< " less than the number of elements required. ";
}
return data[(*count)++];
}

View File

@ -147,6 +147,21 @@ def _tensor_getitem_by_number(data, number_index):
return _tensor_slice(data, number_index)
@getitem.register("Tensor", "None")
def _tensor_getitem_by_none(data, index):
"""
Getting item of tensor by None.
Inputs:
data (Tensor): A tensor.
index (None): None.
Outputs:
Tensor, element type is as same as the element type of data.
"""
return _tensor_slice(data, index)
@getitem.register("Tensor", "Slice")
def _tensor_getitem_by_slice(data, slice_index):
"""

View File

@ -633,7 +633,7 @@ class TruncatedNormal(PrimitiveWithInfer):
dtype (:class:`mindspore.dtype`): Data type. Default: mindspore.float32.
Inputs:
- **shape** (Tensor) - Shape of output tensor. The shape is a 1-D tensor, and type is int.
- **shape** (tuple[int]) - Shape of output tensor, is a tuple of positive int.
Outputs:
Tensor, type of output tensor is same as attribute `dtype`.
@ -651,16 +651,10 @@ class TruncatedNormal(PrimitiveWithInfer):
validator.check_typename('dtype', dtype, mstype.number_type)
def __infer__(self, shape):
shape_t = shape['value']
validator.check_subclass("shape", shape['dtype'], mstype.tensor)
shape_n = shape_t.asnumpy()
if shape_n.ndim != 1:
raise ValueError('The rank of input shape must be 1.')
if shape_n.dtype not in (np.int32, np.int64):
raise TypeError('The type of input shape must be int32 or int64.')
for i, item in enumerate(shape_n):
validator.check_integer(f"shape[{i}]", item.item(), 0, Rel.GT)
out = {'shape': tuple(shape_n),
shape_value = shape['value']
for i, value in enumerate(shape_value):
validator.check_integer(f'{i}th value of shape', value, 0, Rel.GT)
out = {'shape': shape_value,
'dtype': mstype.tensor_type(self.dtype),
'value': None}
return out
@ -1648,20 +1642,19 @@ class StridedSlice(PrimitiveWithInfer):
validator.check_type('shrink_axis_mask', shrink_axis_mask, [int])
def __infer__(self, x, begin, end, strides):
begin_shape, end_shape, strides_shape = begin['shape'], end['shape'], strides['shape']
if begin_shape != strides_shape or end_shape != strides_shape:
raise ValueError("The shape of begin, end and strides in 'StridedSlice' must be equal.")
validator.check_const_input("begin", begin['value'])
validator.check_const_input("end", end['value'])
validator.check_const_input("strides", strides['value'])
validator.check_type("begin", begin['value'], [tuple])
validator.check_type("end", end['value'], [tuple])
validator.check_type("strides", strides['value'], [tuple])
x_shape = x['shape']
x_shp_len = len(x_shape)
begin_v, end_v, strides_v = begin['value'], end['value'], strides['value']
validator.check_const_input("begin", begin_v)
validator.check_const_input("end", end_v)
validator.check_const_input("strides", strides_v)
validator.check_type("begin", begin['value'], [tuple])
validator.check_type("end", end['value'], [tuple])
validator.check_type("strides", strides['value'], [tuple])
if len(begin_v) != x_shp_len or len(end_v) != x_shp_len or len(strides_v) != x_shp_len:
raise ValueError(f"The length of begin index{begin_v}, end index{end_v} and strides{strides_v} "
f"must be equal to the dims({x_shp_len}) of input.")
ret_shape = []
append_dimensions = []
shrink_pos = bin(self.shrink_axis_mask)[::-1]

View File

@ -372,7 +372,7 @@ test_case_math_ops = [
'desc_bprop': [[3]]}),
('TruncatedNormal', {
'block': P.TruncatedNormal(),
'desc_const': [Tensor(np.array([1, 2, 3]))],
'desc_const': [[1, 2, 3]],
'desc_inputs': [],
'skip': ['backward'],
'add_fake_input': True}),

View File

@ -52,8 +52,9 @@ class NetWorkSliceEllipsis(Cell):
def construct(self, tensor):
ret0 = tensor[0:4:2, ..., 1] + self.tensor_ret0
ret1 = tensor[...] + self.tensor_ret1
ret2 = tensor[True] + self.tensor_ret2
return ret0, ret1, ret2
ret2 = tensor[None] + self.tensor_ret2
ret3 = tensor[True] + self.tensor_ret2
return ret0, ret1, ret2, ret3
class NetWorkReduceDimension(Cell):
@ -305,7 +306,7 @@ test_cases = [
'block': NetWorkReduceToScalar(),
'desc_inputs': [Tensor(np.ones([6, 8, 10], np.int32))],
}),
('NetWorkSliceEllipsis', {
('TensorSliceEllipsis', {
'block': NetWorkSliceEllipsis(),
'desc_inputs': [Tensor(np.ones([6, 7, 8, 9], np.int32))],
}),