diff --git a/mindspore/ccsrc/operator/composite/composite.cc b/mindspore/ccsrc/operator/composite/composite.cc
index 11ab31a292d..88db8b8ff84 100644
--- a/mindspore/ccsrc/operator/composite/composite.cc
+++ b/mindspore/ccsrc/operator/composite/composite.cc
@@ -1084,6 +1084,7 @@ int GenerateStridedSliceParametersFromTuple(const AbstractTuplePtr &slice_tuple,
   std::vector<int> shrink;
   auto slice_tuple_eles = slice_tuple->elements();
   size_t ellipsis_num = 0;
+
   for (size_t index = 0; index < slice_tuple_size; index++) {
     if (slice_tuple_eles[index]->isa<AbstractSlice>()) {
       AbstractSlicePtr slice = dyn_cast<AbstractSlice>(slice_tuple_eles[index]);
@@ -1118,12 +1119,13 @@ int GenerateStridedSliceParametersFromTuple(const AbstractTuplePtr &slice_tuple,
                       << slice_tuple_eles[index]->ToString();
   }
 
-  for (size_t index = slice_tuple_size; index < shape_size; index++) {
-    begin->push_back(0);
-    end->push_back(shape[index]);
-    strides->push_back(1);
+  if (ellipsis_num == 0) {
+    for (size_t index = slice_tuple_size; index < shape_size; index++) {
+      begin->push_back(0);
+      end->push_back(shape[index]);
+      strides->push_back(1);
+    }
   }
-
   return ConvertBinaryToDecimal(shrink);
 }
 
@@ -1199,6 +1201,7 @@ FuncGraphPtr TensorSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec
       if (scalar_ptr->BuildValue()->cast<BoolImmPtr>()->value()) {
         return ExpandADim(ret_graph, tensor_node);
       }
+      MS_LOG(EXCEPTION) << "TensorSlice does not support an index of False.";
     }
     shrink_axis_mask = GenerateStridedSliceParametersFromNumber(scalar_ptr, shape, &begin, &end, &strides);
   } else if (args_spec_list[1]->isa<AbstractEllipsis>()) {
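Note: as a reviewer aid, the sketch below shows the NumPy indexing semantics that tensor[True], tensor[None], and ellipsis slicing are being aligned with in the hunks above. It is illustrative only and not part of the patch.

    import numpy as np

    x = np.ones((6, 7, 8, 9), np.int32)

    # A scalar True or None index prepends a new axis of size 1.
    assert x[True].shape == (1, 6, 7, 8, 9)
    assert x[None].shape == (1, 6, 7, 8, 9)

    # An ellipsis already stands for full slices over the untouched
    # dimensions, which is why the trailing begin/end/strides fill loop
    # is now skipped when ellipsis_num != 0.
    assert x[0:4:2, ..., 1].shape == (2, 7, 8)

    # A scalar False prepends an axis of size 0; TensorSlice now raises
    # instead of trying to model an empty result.
    assert x[False].shape == (0, 6, 7, 8, 9)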
"; - } if (output_node->isa()) { + if (*count >= data.size()) { + MS_LOG(EXCEPTION) << "The number of elements in the outputs : " << data.size() + << " less than the number of elements required. "; + } return data[(*count)++]; } diff --git a/mindspore/ops/composite/multitype_ops/getitem_impl.py b/mindspore/ops/composite/multitype_ops/getitem_impl.py index 56617c06a83..540dd28b374 100644 --- a/mindspore/ops/composite/multitype_ops/getitem_impl.py +++ b/mindspore/ops/composite/multitype_ops/getitem_impl.py @@ -147,6 +147,21 @@ def _tensor_getitem_by_number(data, number_index): return _tensor_slice(data, number_index) +@getitem.register("Tensor", "None") +def _tensor_getitem_by_none(data, index): + """ + Getting item of tensor by None. + + Inputs: + data (Tensor): A tensor. + index (None): None. + + Outputs: + Tensor, element type is as same as the element type of data. + """ + return _tensor_slice(data, index) + + @getitem.register("Tensor", "Slice") def _tensor_getitem_by_slice(data, slice_index): """ diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index 43398a5f296..6f51dd0a1cd 100644 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -633,7 +633,7 @@ class TruncatedNormal(PrimitiveWithInfer): dtype (:class:`mindspore.dtype`): Data type. Default: mindspore.float32. Inputs: - - **shape** (Tensor) - Shape of output tensor. The shape is a 1-D tensor, and type is int. + - **shape** (tuple[int]) - Shape of output tensor, is a tuple of positive int. Outputs: Tensor, type of output tensor is same as attribute `dtype`. @@ -651,16 +651,10 @@ class TruncatedNormal(PrimitiveWithInfer): validator.check_typename('dtype', dtype, mstype.number_type) def __infer__(self, shape): - shape_t = shape['value'] - validator.check_subclass("shape", shape['dtype'], mstype.tensor) - shape_n = shape_t.asnumpy() - if shape_n.ndim != 1: - raise ValueError('The rank of input shape must be 1.') - if shape_n.dtype not in (np.int32, np.int64): - raise TypeError('The type of input shape must be int32 or int64.') - for i, item in enumerate(shape_n): - validator.check_integer(f"shape[{i}]", item.item(), 0, Rel.GT) - out = {'shape': tuple(shape_n), + shape_value = shape['value'] + for i, value in enumerate(shape_value): + validator.check_integer(f'{i}th value of shape', value, 0, Rel.GT) + out = {'shape': shape_value, 'dtype': mstype.tensor_type(self.dtype), 'value': None} return out @@ -1648,20 +1642,19 @@ class StridedSlice(PrimitiveWithInfer): validator.check_type('shrink_axis_mask', shrink_axis_mask, [int]) def __infer__(self, x, begin, end, strides): - begin_shape, end_shape, strides_shape = begin['shape'], end['shape'], strides['shape'] - if begin_shape != strides_shape or end_shape != strides_shape: - raise ValueError("The shape of begin, end and strides in 'StridedSlice' must be equal.") - - validator.check_const_input("begin", begin['value']) - validator.check_const_input("end", end['value']) - validator.check_const_input("strides", strides['value']) - validator.check_type("begin", begin['value'], [tuple]) - validator.check_type("end", end['value'], [tuple]) - validator.check_type("strides", strides['value'], [tuple]) - x_shape = x['shape'] x_shp_len = len(x_shape) begin_v, end_v, strides_v = begin['value'], end['value'], strides['value'] + validator.check_const_input("begin", begin_v) + validator.check_const_input("end", end_v) + validator.check_const_input("strides", strides_v) + validator.check_type("begin", begin['value'], [tuple]) 
+ validator.check_type("end", end['value'], [tuple]) + validator.check_type("strides", strides['value'], [tuple]) + if len(begin_v) != x_shp_len or len(end_v) != x_shp_len or len(strides_v) != x_shp_len: + raise ValueError(f"The length of begin index{begin_v}, end index{end_v} and strides{strides_v} " + f"must be equal to the dims({x_shp_len}) of input.") + ret_shape = [] append_dimensions = [] shrink_pos = bin(self.shrink_axis_mask)[::-1] diff --git a/tests/ut/python/ops/test_ops.py b/tests/ut/python/ops/test_ops.py index 8b14ea23664..22df3d1fd34 100755 --- a/tests/ut/python/ops/test_ops.py +++ b/tests/ut/python/ops/test_ops.py @@ -372,7 +372,7 @@ test_case_math_ops = [ 'desc_bprop': [[3]]}), ('TruncatedNormal', { 'block': P.TruncatedNormal(), - 'desc_const': [Tensor(np.array([1, 2, 3]))], + 'desc_const': [[1, 2, 3]], 'desc_inputs': [], 'skip': ['backward'], 'add_fake_input': True}), diff --git a/tests/ut/python/ops/test_tensor_slice.py b/tests/ut/python/ops/test_tensor_slice.py index 08ba143de82..b547fdc06b9 100644 --- a/tests/ut/python/ops/test_tensor_slice.py +++ b/tests/ut/python/ops/test_tensor_slice.py @@ -52,8 +52,9 @@ class NetWorkSliceEllipsis(Cell): def construct(self, tensor): ret0 = tensor[0:4:2, ..., 1] + self.tensor_ret0 ret1 = tensor[...] + self.tensor_ret1 - ret2 = tensor[True] + self.tensor_ret2 - return ret0, ret1, ret2 + ret2 = tensor[None] + self.tensor_ret2 + ret3 = tensor[True] + self.tensor_ret2 + return ret0, ret1, ret2, ret3 class NetWorkReduceDimension(Cell): @@ -305,7 +306,7 @@ test_cases = [ 'block': NetWorkReduceToScalar(), 'desc_inputs': [Tensor(np.ones([6, 8, 10], np.int32))], }), - ('NetWorkSliceEllipsis', { + ('TensorSliceEllipsis', { 'block': NetWorkSliceEllipsis(), 'desc_inputs': [Tensor(np.ones([6, 7, 8, 9], np.int32))], }),