forked from mindspore-Ecosystem/mindspore
fix the infer of TruncatedNormal and a bug of structure output and a bug of tensorslice ellipsis
This commit is contained in:
parent
ffdad1acd4
commit
c8a1a24ce3
|
@ -1084,6 +1084,7 @@ int GenerateStridedSliceParametersFromTuple(const AbstractTuplePtr &slice_tuple,
|
|||
std::vector<unsigned int> shrink;
|
||||
auto slice_tuple_eles = slice_tuple->elements();
|
||||
size_t ellipsis_num = 0;
|
||||
|
||||
for (size_t index = 0; index < slice_tuple_size; index++) {
|
||||
if (slice_tuple_eles[index]->isa<AbstractSlice>()) {
|
||||
AbstractSlicePtr slice = dyn_cast<AbstractSlice>(slice_tuple_eles[index]);
|
||||
|
@ -1118,12 +1119,13 @@ int GenerateStridedSliceParametersFromTuple(const AbstractTuplePtr &slice_tuple,
|
|||
<< slice_tuple_eles[index]->ToString();
|
||||
}
|
||||
|
||||
for (size_t index = slice_tuple_size; index < shape_size; index++) {
|
||||
begin->push_back(0);
|
||||
end->push_back(shape[index]);
|
||||
strides->push_back(1);
|
||||
if (ellipsis_num == 0) {
|
||||
for (size_t index = slice_tuple_size; index < shape_size; index++) {
|
||||
begin->push_back(0);
|
||||
end->push_back(shape[index]);
|
||||
strides->push_back(1);
|
||||
}
|
||||
}
|
||||
|
||||
return ConvertBinaryToDecimal(shrink);
|
||||
}
|
||||
|
||||
|
@ -1199,6 +1201,7 @@ FuncGraphPtr TensorSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec
|
|||
if (scalar_ptr->BuildValue()->cast<BoolImmPtr>()->value()) {
|
||||
return ExpandADim(ret_graph, tensor_node);
|
||||
}
|
||||
MS_LOG(EXCEPTION) << "TensorSlice not support the index is False.";
|
||||
}
|
||||
shrink_axis_mask = GenerateStridedSliceParametersFromNumber(scalar_ptr, shape, &begin, &end, &strides);
|
||||
} else if (args_spec_list[1]->isa<AbstractEllipsis>()) {
|
||||
|
|
|
@ -319,19 +319,24 @@ void RunGEInitGraph(const py::dict &init_params, const std::string &phase) {
|
|||
|
||||
py::object ExtractGeneralCnodeRet(const AbstractBasePtr &cnode_data, const py::tuple &data, size_t *count) {
|
||||
MS_EXCEPTION_IF_NULL(cnode_data);
|
||||
if (*count >= data.size()) {
|
||||
MS_LOG(EXCEPTION) << "The number of elements in the outputs : " << data.size()
|
||||
<< " less than the number of elements required. ";
|
||||
}
|
||||
|
||||
if (cnode_data->isa<AbstractTensor>()) {
|
||||
BaseShapePtr shape = cnode_data->BuildShape();
|
||||
auto shape_act = shape->cast<abstract::ShapePtr>()->shape();
|
||||
Tensor tensor_exp = py::cast<Tensor>(data[*count]);
|
||||
if (shape_act != tensor_exp.shape()) {
|
||||
MS_LOG(EXCEPTION) << "The shape of the tensor returned from GE is not the same as "
|
||||
"the shape of the tensor derived from ME.";
|
||||
if (*count >= data.size()) {
|
||||
MS_LOG(EXCEPTION) << "The number of elements in the outputs : " << data.size()
|
||||
<< " less than the number of elements required. ";
|
||||
}
|
||||
|
||||
BaseShapePtr shape = cnode_data->BuildShape();
|
||||
if (!shape->isa<abstract::Shape>()) {
|
||||
MS_LOG(EXCEPTION) << "The shape of the tensor derived is not Shape, is " << shape->ToString();
|
||||
}
|
||||
auto shape_me = shape->cast<abstract::ShapePtr>()->shape();
|
||||
auto shape_ge = py::cast<Tensor>(data[*count]).shape();
|
||||
if (shape_ge != shape_me) {
|
||||
MS_LOG(EXCEPTION) << "The shape of the " << *count << "th tensor returned: " << shape_ge
|
||||
<< " is not the same as the shape of the tensor derived: " << shape_me;
|
||||
}
|
||||
|
||||
return data[(*count)++];
|
||||
}
|
||||
|
||||
|
@ -357,11 +362,11 @@ py::object StructureOutput(const AnfNodePtr &output_node, const py::tuple &data,
|
|||
return ValuePtrToPyData(GetValueNode(output_node));
|
||||
}
|
||||
|
||||
if (*count >= data.size()) {
|
||||
MS_LOG(EXCEPTION) << "The number of elements in the outputs : " << data.size()
|
||||
<< " less than the number of elements required. ";
|
||||
}
|
||||
if (output_node->isa<Parameter>()) {
|
||||
if (*count >= data.size()) {
|
||||
MS_LOG(EXCEPTION) << "The number of elements in the outputs : " << data.size()
|
||||
<< " less than the number of elements required. ";
|
||||
}
|
||||
return data[(*count)++];
|
||||
}
|
||||
|
||||
|
|
|
@ -147,6 +147,21 @@ def _tensor_getitem_by_number(data, number_index):
|
|||
return _tensor_slice(data, number_index)
|
||||
|
||||
|
||||
@getitem.register("Tensor", "None")
|
||||
def _tensor_getitem_by_none(data, index):
|
||||
"""
|
||||
Getting item of tensor by None.
|
||||
|
||||
Inputs:
|
||||
data (Tensor): A tensor.
|
||||
index (None): None.
|
||||
|
||||
Outputs:
|
||||
Tensor, element type is as same as the element type of data.
|
||||
"""
|
||||
return _tensor_slice(data, index)
|
||||
|
||||
|
||||
@getitem.register("Tensor", "Slice")
|
||||
def _tensor_getitem_by_slice(data, slice_index):
|
||||
"""
|
||||
|
|
|
@ -633,7 +633,7 @@ class TruncatedNormal(PrimitiveWithInfer):
|
|||
dtype (:class:`mindspore.dtype`): Data type. Default: mindspore.float32.
|
||||
|
||||
Inputs:
|
||||
- **shape** (Tensor) - Shape of output tensor. The shape is a 1-D tensor, and type is int.
|
||||
- **shape** (tuple[int]) - Shape of output tensor, is a tuple of positive int.
|
||||
|
||||
Outputs:
|
||||
Tensor, type of output tensor is same as attribute `dtype`.
|
||||
|
@ -651,16 +651,10 @@ class TruncatedNormal(PrimitiveWithInfer):
|
|||
validator.check_typename('dtype', dtype, mstype.number_type)
|
||||
|
||||
def __infer__(self, shape):
|
||||
shape_t = shape['value']
|
||||
validator.check_subclass("shape", shape['dtype'], mstype.tensor)
|
||||
shape_n = shape_t.asnumpy()
|
||||
if shape_n.ndim != 1:
|
||||
raise ValueError('The rank of input shape must be 1.')
|
||||
if shape_n.dtype not in (np.int32, np.int64):
|
||||
raise TypeError('The type of input shape must be int32 or int64.')
|
||||
for i, item in enumerate(shape_n):
|
||||
validator.check_integer(f"shape[{i}]", item.item(), 0, Rel.GT)
|
||||
out = {'shape': tuple(shape_n),
|
||||
shape_value = shape['value']
|
||||
for i, value in enumerate(shape_value):
|
||||
validator.check_integer(f'{i}th value of shape', value, 0, Rel.GT)
|
||||
out = {'shape': shape_value,
|
||||
'dtype': mstype.tensor_type(self.dtype),
|
||||
'value': None}
|
||||
return out
|
||||
|
@ -1648,20 +1642,19 @@ class StridedSlice(PrimitiveWithInfer):
|
|||
validator.check_type('shrink_axis_mask', shrink_axis_mask, [int])
|
||||
|
||||
def __infer__(self, x, begin, end, strides):
|
||||
begin_shape, end_shape, strides_shape = begin['shape'], end['shape'], strides['shape']
|
||||
if begin_shape != strides_shape or end_shape != strides_shape:
|
||||
raise ValueError("The shape of begin, end and strides in 'StridedSlice' must be equal.")
|
||||
|
||||
validator.check_const_input("begin", begin['value'])
|
||||
validator.check_const_input("end", end['value'])
|
||||
validator.check_const_input("strides", strides['value'])
|
||||
validator.check_type("begin", begin['value'], [tuple])
|
||||
validator.check_type("end", end['value'], [tuple])
|
||||
validator.check_type("strides", strides['value'], [tuple])
|
||||
|
||||
x_shape = x['shape']
|
||||
x_shp_len = len(x_shape)
|
||||
begin_v, end_v, strides_v = begin['value'], end['value'], strides['value']
|
||||
validator.check_const_input("begin", begin_v)
|
||||
validator.check_const_input("end", end_v)
|
||||
validator.check_const_input("strides", strides_v)
|
||||
validator.check_type("begin", begin['value'], [tuple])
|
||||
validator.check_type("end", end['value'], [tuple])
|
||||
validator.check_type("strides", strides['value'], [tuple])
|
||||
if len(begin_v) != x_shp_len or len(end_v) != x_shp_len or len(strides_v) != x_shp_len:
|
||||
raise ValueError(f"The length of begin index{begin_v}, end index{end_v} and strides{strides_v} "
|
||||
f"must be equal to the dims({x_shp_len}) of input.")
|
||||
|
||||
ret_shape = []
|
||||
append_dimensions = []
|
||||
shrink_pos = bin(self.shrink_axis_mask)[::-1]
|
||||
|
|
|
@ -372,7 +372,7 @@ test_case_math_ops = [
|
|||
'desc_bprop': [[3]]}),
|
||||
('TruncatedNormal', {
|
||||
'block': P.TruncatedNormal(),
|
||||
'desc_const': [Tensor(np.array([1, 2, 3]))],
|
||||
'desc_const': [[1, 2, 3]],
|
||||
'desc_inputs': [],
|
||||
'skip': ['backward'],
|
||||
'add_fake_input': True}),
|
||||
|
|
|
@ -52,8 +52,9 @@ class NetWorkSliceEllipsis(Cell):
|
|||
def construct(self, tensor):
|
||||
ret0 = tensor[0:4:2, ..., 1] + self.tensor_ret0
|
||||
ret1 = tensor[...] + self.tensor_ret1
|
||||
ret2 = tensor[True] + self.tensor_ret2
|
||||
return ret0, ret1, ret2
|
||||
ret2 = tensor[None] + self.tensor_ret2
|
||||
ret3 = tensor[True] + self.tensor_ret2
|
||||
return ret0, ret1, ret2, ret3
|
||||
|
||||
|
||||
class NetWorkReduceDimension(Cell):
|
||||
|
@ -305,7 +306,7 @@ test_cases = [
|
|||
'block': NetWorkReduceToScalar(),
|
||||
'desc_inputs': [Tensor(np.ones([6, 8, 10], np.int32))],
|
||||
}),
|
||||
('NetWorkSliceEllipsis', {
|
||||
('TensorSliceEllipsis', {
|
||||
'block': NetWorkSliceEllipsis(),
|
||||
'desc_inputs': [Tensor(np.ones([6, 7, 8, 9], np.int32))],
|
||||
}),
|
||||
|
|
Loading…
Reference in New Issue