support tensor get value by tensor index

support tensor set value by tensor index
This commit is contained in:
buxue 2020-05-12 11:02:43 +08:00
parent ca74e624e2
commit e490618db8
22 changed files with 1272 additions and 426 deletions

View File

@ -1172,6 +1172,12 @@ int GenerateStridedSliceParametersFromNumber(const AbstractScalarPtr &scalar, co
return 1; return 1;
} }
FuncGraphPtr ExpandADim(const FuncGraphPtr &ret_graph, const AnfNodePtr &tensor_node) {
auto PrimExpandDims = GetPythonOps("expand_dims", "mindspore.ops.functional");
ret_graph->set_output(NewCNode({NewValueNode(PrimExpandDims), tensor_node, NewValueNode(0)}, ret_graph));
return ret_graph;
}
FuncGraphPtr TensorSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) { FuncGraphPtr TensorSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
// slice a tensor // slice a tensor
// args: tensor, slice or slice tuple // args: tensor, slice or slice tuple
@ -1229,12 +1235,6 @@ FuncGraphPtr TensorSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec
return ret_graph; return ret_graph;
} }
FuncGraphPtr TensorSlice::ExpandADim(const FuncGraphPtr &ret_graph, const AnfNodePtr &tensor_node) const {
auto PrimExpandDims = GetPythonOps("expand_dims", "mindspore.ops.functional");
ret_graph->set_output(NewCNode({NewValueNode(PrimExpandDims), tensor_node, NewValueNode(0)}, ret_graph));
return ret_graph;
}
FuncGraphPtr TupleGetItemTensor::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) { FuncGraphPtr TupleGetItemTensor::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
// select indexed item // select indexed item
// args: tuple of items, index // args: tuple of items, index

View File

@ -206,8 +206,6 @@ class TensorSlice : public MetaFuncGraph {
MS_DECLARE_PARENT(TensorSlice, MetaFuncGraph) MS_DECLARE_PARENT(TensorSlice, MetaFuncGraph)
FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override; FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
friend bool operator==(const TensorSlice &lhs, const TensorSlice &rhs) { return lhs.name_ == rhs.name_; } friend bool operator==(const TensorSlice &lhs, const TensorSlice &rhs) { return lhs.name_ == rhs.name_; }
FuncGraphPtr ExpandADim(const FuncGraphPtr &ret_graph, const AnfNodePtr &tensor_node) const;
}; };
using TensorSlicePtr = std::shared_ptr<TensorSlice>; using TensorSlicePtr = std::shared_ptr<TensorSlice>;

View File

@ -101,6 +101,7 @@ const char kNameReLU6[] = "ReLU6";
const char kNameReLU6Grad[] = "ReLU6Grad"; const char kNameReLU6Grad[] = "ReLU6Grad";
const char kNameElu[] = "Elu"; const char kNameElu[] = "Elu";
const char kNameEluGrad[] = "EluGrad"; const char kNameEluGrad[] = "EluGrad";
const char kNameScatterUpdate[] = "ScatterUpdate";
const char kNameScatterNdUpdate[] = "ScatterNdUpdate"; const char kNameScatterNdUpdate[] = "ScatterNdUpdate";
const char kNameScatterMax[] = "ScatterMax"; const char kNameScatterMax[] = "ScatterMax";
const char kNameNMSWithMask[] = "NMSWithMask"; const char kNameNMSWithMask[] = "NMSWithMask";
@ -256,6 +257,7 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma
{string(kNameResizeBilinear), ADPT_DESC(ResizeBilinearV2D)}, {string(kNameResizeBilinear), ADPT_DESC(ResizeBilinearV2D)},
{string(kNameZerosLike), ADPT_DESC(ZerosLike)}, {string(kNameZerosLike), ADPT_DESC(ZerosLike)},
{string(kNameOnesLike), ADPT_DESC(OnesLike)}, {string(kNameOnesLike), ADPT_DESC(OnesLike)},
{string(kNameScatterUpdate), ADPT_DESC(ScatterUpdate)},
{string(kNameScatterNdUpdate), ADPT_DESC(ScatterNdUpdate)}, {string(kNameScatterNdUpdate), ADPT_DESC(ScatterNdUpdate)},
{string(kNameScatterMax), ADPT_DESC(ScatterMax)}, {string(kNameScatterMax), ADPT_DESC(ScatterMax)},
{string(kNameNMSWithMask), ADPT_DESC(NMSWithMask)}, {string(kNameNMSWithMask), ADPT_DESC(NMSWithMask)},

View File

@ -515,6 +515,11 @@ INPUT_MAP(Unpack) = {{1, INPUT_DESC(x)}};
ATTR_MAP(Unpack) = {{"axis", ATTR_DESC(axis, AnyTraits<int>())}, {"num", ATTR_DESC(num, AnyTraits<int>())}}; ATTR_MAP(Unpack) = {{"axis", ATTR_DESC(axis, AnyTraits<int>())}, {"num", ATTR_DESC(num, AnyTraits<int>())}};
DYN_OUTPUT_MAP(Unpack) = {{0, DYN_OUTPUT_DESC(y)}}; DYN_OUTPUT_MAP(Unpack) = {{0, DYN_OUTPUT_DESC(y)}};
// ScatterUpdate
INPUT_MAP(ScatterUpdate) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(indices)}, {3, INPUT_DESC(updates)}};
ATTR_MAP(ScatterUpdate) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
OUTPUT_MAP(ScatterUpdate) = {{0, OUTPUT_DESC(var)}};
// ScatterNdUpdate // ScatterNdUpdate
INPUT_MAP(ScatterNdUpdate) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(indices)}, {3, INPUT_DESC(updates)}}; INPUT_MAP(ScatterNdUpdate) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(indices)}, {3, INPUT_DESC(updates)}};
ATTR_MAP(ScatterNdUpdate) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}}; ATTR_MAP(ScatterNdUpdate) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};

View File

@ -132,6 +132,8 @@ DECLARE_OP_ADAPTER(ZerosLike)
DECLARE_OP_USE_OUTPUT(ZerosLike) DECLARE_OP_USE_OUTPUT(ZerosLike)
DECLARE_OP_ADAPTER(OnesLike) DECLARE_OP_ADAPTER(OnesLike)
DECLARE_OP_USE_OUTPUT(OnesLike) DECLARE_OP_USE_OUTPUT(OnesLike)
DECLARE_OP_ADAPTER(ScatterUpdate)
DECLARE_OP_USE_OUTPUT(ScatterUpdate)
DECLARE_OP_ADAPTER(ScatterNdUpdate) DECLARE_OP_ADAPTER(ScatterNdUpdate)
DECLARE_OP_USE_OUTPUT(ScatterNdUpdate) DECLARE_OP_USE_OUTPUT(ScatterNdUpdate)
DECLARE_OP_ADAPTER(ScatterMax) DECLARE_OP_ADAPTER(ScatterMax)

View File

@ -178,13 +178,14 @@ from .bounding_box_encode import _bounding_box_encode_tbe
from .check_valid import _check_valid_tbe from .check_valid import _check_valid_tbe
from .iou import _iou_tbe from .iou import _iou_tbe
from .arg_max import _arg_max_tbe from .arg_max import _arg_max_tbe
from .nms_with_mask import nms_with_mask_op_info from .nms_with_mask import _nms_with_mask_tbe
from .random_choice_with_mask import random_choice_with_mask_op_info from .random_choice_with_mask import _random_choice_with_mask_tbe
from .sgd import sgd_op_info from .sgd import _sgd_tbe
from .lars_update import lars_update_op_info from .lars_update import _lars_update_tbe
from .bn_training_update_v2 import _bn_training_update_v2_tbe from .bn_training_update_v2 import _bn_training_update_v2_tbe
from .square_sum_all import square_sum_all_op_info from .square_sum_all import _square_sum_all_tbe
from .pack import _pack_tbe from .pack import _pack_tbe
from .unpack import _unpack_tbe from .unpack import _unpack_tbe
from .scatter_update import _scatter_update_tbe
from .prelu import _prelu_tbe from .prelu import _prelu_tbe
from .prelu_grad import _prelu_grad_tbe from .prelu_grad import _prelu_grad_tbe

View File

@ -0,0 +1,42 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ScatterUpdate op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
scatter_update_op_info = TBERegOp("ScatterUpdate") \
.fusion_type("ELEMWISE") \
.async_flag(False) \
.binfile_name("scatter_update.so") \
.compute_cost(10) \
.kernel_name("scatter_update") \
.partial_flag(True) \
.attr("use_locking", "optional", "bool", "all") \
.input(0, "var", False, "required", "all") \
.input(1, "indices", False, "required", "all") \
.input(1, "updates", False, "required", "all") \
.output(0, "var", False, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default, DataType.U8_Default,) \
.dtype_format(DataType.BOOL_Default, DataType.I32_Default, DataType.BOOL_Default, DataType.BOOL_Default) \
.get_op_info()
@op_info_register(scatter_update_op_info)
def _scatter_update_tbe():
"""ScatterUpdate TBE register"""
return

View File

@ -14,6 +14,6 @@
# ============================================================================ # ============================================================================
"""ops utils.""" """ops utils."""
from .utils import _get_broadcast_shape, _get_concat_offset from .utils import get_broadcast_shape, get_concat_offset
__all__ = ['_get_broadcast_shape', '_get_concat_offset'] __all__ = ['get_broadcast_shape', 'get_concat_offset']

View File

@ -19,7 +19,8 @@ from ..._checkparam import Validator as validator
from ..._checkparam import Rel from ..._checkparam import Rel
from ...common import dtype as mstype from ...common import dtype as mstype
def _get_broadcast_shape(x_shape, y_shape, prim_name):
def get_broadcast_shape(x_shape, y_shape, prim_name):
""" """
Doing broadcast between tensor x and tensor y. Doing broadcast between tensor x and tensor y.
@ -37,7 +38,7 @@ def _get_broadcast_shape(x_shape, y_shape, prim_name):
Examples: Examples:
>>> x_shape = [1, 2, 3] >>> x_shape = [1, 2, 3]
>>> y_shape = [1, 2] >>> y_shape = [1, 2]
>>> broadcast_shape = _get_broadcast_shape(x_shape, y_shape) >>> broadcast_shape = get_broadcast_shape(x_shape, y_shape)
""" """
if x_shape == y_shape: if x_shape == y_shape:
return x_shape return x_shape
@ -54,15 +55,14 @@ def _get_broadcast_shape(x_shape, y_shape, prim_name):
elif x_shape[i] == y_shape[i]: elif x_shape[i] == y_shape[i]:
broadcast_shape_back.append(x_shape[i]) broadcast_shape_back.append(x_shape[i])
else: else:
raise ValueError("For '{}' the x_shape {} and y_shape {} can not broadcast.".format( raise ValueError(f"For '{prim_name}', the x_shape {x_shape} and y_shape {y_shape} can not broadcast.")
prim_name, x_shape, y_shape))
broadcast_shape_front = y_shape[0: y_len - length] if length == x_len else x_shape[0: x_len - length] broadcast_shape_front = y_shape[0: y_len - length] if length == x_len else x_shape[0: x_len - length]
broadcast_shape = broadcast_shape_front + broadcast_shape_back broadcast_shape = list(broadcast_shape_front) + broadcast_shape_back
return broadcast_shape return broadcast_shape
def _get_concat_offset(x_shp, x_type, axis, prim_name): def get_concat_offset(x_shp, x_type, axis, prim_name):
"""for concat and concatoffset check args and compute offset""" """for concat and concatoffset check args and compute offset"""
validator.check_value_type("shape", x_shp, [tuple], prim_name) validator.check_value_type("shape", x_shp, [tuple], prim_name)
validator.check_integer("input_x rank", len(x_shp), 0, Rel.GT, prim_name) validator.check_integer("input_x rank", len(x_shp), 0, Rel.GT, prim_name)
@ -73,7 +73,7 @@ def _get_concat_offset(x_shp, x_type, axis, prim_name):
if axis < 0: if axis < 0:
axis = axis + rank_base axis = axis + rank_base
all_shp = x_shp[0][axis] all_shp = x_shp[0][axis]
offset = [0,] offset = [0]
for i in range(1, len(x_shp)): for i in range(1, len(x_shp)):
v = x_shp[i] v = x_shp[i]
validator.check('len of x_shp[%d]' % i, len(v), 'len of x_shp[0]', len(x_shp[0]), Rel.EQ, prim_name) validator.check('len of x_shp[%d]' % i, len(v), 'len of x_shp[0]', len(x_shp[0]), Rel.EQ, prim_name)

View File

@ -1,226 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""constexpr util"""
from functools import reduce
import numpy as np
from ...primitive import constexpr
from ....common.tensor import Tensor
from ....common import dtype as mstype
from ...._extends.utils import Slice, Ellipsis_
@constexpr
def check_equal(param1, param2, msg="{},{}"):
"""Checks whether the two parameters are equal or not."""
if param1 != param2:
raise ValueError(msg.format(param1, param2))
return param1
@constexpr
def check_ellipsis_shape_size(data_shape, value_shape, data_size, value_size):
"""Checks the shape and size of the sensor and value."""
if data_shape == value_shape or data_size == value_size or value_size == 1:
return True
raise ValueError("The value(shape={}), can not assign to tensor(shape={}).".format(value_shape, data_shape))
@constexpr
def check_tensor_setitem_index(index, element_type=None):
"""Checks tuple index type of tensor assignment."""
if index is None:
raise IndexError("Tensor's index cannot be None.")
# eg. Tensor[Slice] = u
if isinstance(index, Slice):
return True
# eg. Tensor[tuple] = u
if isinstance(index, tuple):
if not index:
raise IndexError("Tensor's index cannot be empty.")
# eg. Tensor[tuple(Slice...)] = u
if isinstance(index[0], (Slice, Ellipsis_, int)):
return True
raise IndexError("Index of type '{}' is not supported yet.".format(type(index[0])))
# eg. Tensor[Tensor[dtype=bool]] = u
if index == mstype.tensor:
if element_type is None or element_type != mstype.bool_:
raise TypeError(
"The index of tensor should be a bool type tensor. "
"{} type is not supported yet.".format(element_type))
return True
raise IndexError("Index of type '{}' is not supported yet.".format(type(index)))
@constexpr
def is_same_type(inst, type_):
"""
Checks whether an object is an instance of a target type.
Inputs:
inst (mindspore.dtype): Inspected type.
type_ (mindspore.dtype): Target type.
Outputs:
bool, the check result.
"""
return inst == type_
def slice_expand(input_slices, shape):
"""
Converts slice to indices.
Inputs:
slices (Union[Slice, tuple[Slice]]): Slice tuple or slice.
shape (tuple): The shape of a sensor is an integer element tuple.
Outputs:
tuple[list], This is expressed as (begins, ends, strides).
"""
begin = []
end = []
strides = []
index = 0
slices = None
# Slice or tuple(Slice...)
if isinstance(input_slices, Slice):
slices = (input_slices,)
elif isinstance(input_slices, (tuple, list)) and input_slices and isinstance(input_slices[0], (Slice, Ellipsis_)):
is_have_ellipsis = False
for _, element in enumerate(input_slices):
if isinstance(element, Ellipsis_):
is_have_ellipsis = True
break
if is_have_ellipsis:
slices = ellipsis2slice(input_slices, shape)
else:
slices = input_slices
else:
raise IndexError("Tensor's index type is not supported yet.")
for s in slices:
start = 0 if (s.start is None) else s.start
stop = shape[index] if (s.end is None) else s.end
step = 1 if (s.step is None) else s.step
begin.append(start)
end.append(stop)
strides.append(step)
index += 1
while index < len(shape):
begin.append(0)
end.append(shape[index])
strides.append(1)
index += 1
return begin, end, strides
def ellipsis2slice(input_, shape):
"""Converts ellipsis to slice."""
input_slice = input_
result = []
if isinstance(input_, Ellipsis_):
input_slice = (input_,)
ell_count = 0
for _, element in enumerate(input_slice):
if not isinstance(element, Ellipsis_):
result.append(element)
continue
ell_count += 1
if ell_count > 1:
raise IndexError("There cannot be more than one ellisis (...) in the index of the tensor, "
"but it is currently {}".format(input_slice))
for _ in range(len(shape) - len(input_slice) + 1):
result.append(Slice(None, None, None))
return tuple(result)
@constexpr
def slice2indices(input_slices, shape):
"""
Converts slice to indices.
Inputs:
slices (Union[Slice, tuple[Slice]]): Slice tuple or slice.
shape (tuple): The shape of a tensor is an integer element tuple.
Outputs:
Tensor, the shape is (n, 1).
"""
begin, end, strides = slice_expand(input_slices, shape)
np_r = []
for i, element in enumerate(shape):
s = begin[i] if (begin[i] >= 0) else (element + begin[i])
e = end[i] if (end[i] >= 0) else (element + end[i])
np_r.append(np.r_[s:e:strides[i]])
# Reference: np.ravel_multi_index((np.ix_(np.r_[1:3:1], np.r_[0:4:1], np.r_[4:0:-1])), a.shape)
np_ix = np.ix_(*np_r)
ravel = np.ravel_multi_index(np_ix, shape)
ravel = Tensor(ravel.reshape(-1, 1), dtype=mstype.int32)
return ravel
@constexpr
def check_indices(indices_size, index):
"""Checks indices whether is empty."""
if indices_size < 1:
raise IndexError("The tensor's index is unreasonable. index:{}".format(index))
return indices_size
@constexpr
def check_indices_value_size(indices_size, value_size):
"""Checks if the sizes are already matched."""
if value_size < 1:
raise ValueError("The value assigned to tensor cannot be empty.")
if value_size > 1:
if value_size != indices_size:
raise ValueError(
"The value given to tensor does not match the index size,"
" value size:{}, indics size:{}".format(value_size, indices_size))
return value_size
@constexpr
def integer_to_indices(index, shape):
"""Converts int or tuple[int] to indices."""
size = reduce(lambda x, y: x * y, shape)
range_ = np.arange(size).reshape(shape)
value = range_[index]
value = value.reshape(-1, 1)
return Tensor(value, dtype=mstype.int32)
@constexpr
def tuple_element_is_slice(indexs):
"""Judges tuple element type."""
if not indexs:
raise IndexError("Tensor's index cannot be empty.")
if isinstance(indexs, tuple):
for _, ele in enumerate(indexs):
if not isinstance(ele, Slice):
return False
return True
return False
@constexpr
def tuple_element_is_int(indexs):
"""Judges tuple element type."""
if not indexs:
raise IndexError("Tensor's index cannot be empty.")
if isinstance(indexs, tuple):
for _, ele in enumerate(indexs):
if not isinstance(ele, int):
return False
return True
return False

View File

@ -0,0 +1,487 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""constexpr util"""
from functools import reduce
import numpy as np
from ...primitive import constexpr
from ....common.tensor import Tensor
from ....common import dtype as mstype
from ...._extends.utils import Slice, Ellipsis_
from ....ops import _utils as op_utils
from ...composite import base
from .... import log as logger
from ... import functional as F
from ... import operations as P
hyper_map = base.HyperMap()
pack = P.Pack(axis=-1)
ALL_TENSOR = 0
NO_TENSOR = 1
CONTAIN_TENSOR = 2
ALL_SCALAR = 3
INT_ = 0
BOOL_ = 1
UNSUPPORTED_DTYPE = 2
TENSOR_SETITEM = "tensor setitem"
TENSOR_GETITEM = "tensor getitem"
SET_ITEM_BY_ONE_TENSOR = 0
SET_ITEM_BY_TUPLE_OF_TENSOR = 1
@constexpr
def check_equal(param1, param2, msg="{},{}"):
"""Checks whether the two parameters are equal or not."""
if param1 != param2:
raise ValueError(msg.format(param1, param2))
return param1
@constexpr
def check_ellipsis_shape_size(data_shape, value_shape, data_size, value_size):
"""Checks the shape and size of the sensor and value."""
if data_shape == value_shape or data_size == value_size or value_size == 1:
return True
raise ValueError("The value(shape={}), can not assign to tensor(shape={}).".format(value_shape, data_shape))
@constexpr
def check_tensor_setitem_index(index, element_type=None):
"""Checks tuple index type of tensor assignment."""
if index is None:
raise IndexError("Tensor's index cannot be None.")
# eg. Tensor[Slice] = u
if isinstance(index, Slice):
return True
# eg. Tensor[tuple] = u
if isinstance(index, tuple):
if not index:
raise IndexError("Tensor's index cannot be empty.")
# eg. Tensor[tuple(Slice...)] = u
if isinstance(index[0], (Slice, Ellipsis_, int)):
return True
raise IndexError("Index of type '{}' is not supported yet.".format(type(index[0])))
# eg. Tensor[Tensor[dtype=bool]] = u
if isinstance(index, mstype.tensor_type):
if element_type is None or element_type != mstype.bool_:
raise TypeError(
"The index of tensor should be a bool type tensor. "
"{} type is not supported yet.".format(element_type))
return True
raise IndexError("Index of type '{}' is not supported yet.".format(type(index)))
@constexpr
def is_same_type(inst, type_):
"""
Checks whether an object is an instance of a target type.
Inputs:
inst (mindspore.dtype): Inspected type.
type_ (mindspore.dtype): Target type.
Outputs:
bool, the check result.
"""
return inst == type_
def slice_expand(input_slices, shape):
"""
Converts slice to indices.
Inputs:
slices (Union[Slice, tuple[Slice]]): Slice tuple or slice.
shape (tuple): The shape of a sensor is an integer element tuple.
Outputs:
tuple[list], This is expressed as (begins, ends, strides).
"""
begin = []
end = []
strides = []
index = 0
slices = None
# Slice or tuple(Slice...)
if isinstance(input_slices, Slice):
slices = (input_slices,)
elif isinstance(input_slices, (tuple, list)) and input_slices and isinstance(input_slices[0], (Slice, Ellipsis_)):
is_have_ellipsis = False
for _, element in enumerate(input_slices):
if isinstance(element, Ellipsis_):
is_have_ellipsis = True
break
if is_have_ellipsis:
slices = ellipsis2slice(input_slices, shape)
else:
slices = input_slices
else:
raise IndexError("Tensor's index type is not supported yet.")
for s in slices:
start = 0 if (s.start is None) else s.start
stop = shape[index] if (s.end is None) else s.end
step = 1 if (s.step is None) else s.step
begin.append(start)
end.append(stop)
strides.append(step)
index += 1
while index < len(shape):
begin.append(0)
end.append(shape[index])
strides.append(1)
index += 1
return begin, end, strides
def ellipsis2slice(input_, shape):
"""Converts ellipsis to slice."""
input_slice = input_
result = []
if isinstance(input_, Ellipsis_):
input_slice = (input_,)
ell_count = 0
for _, element in enumerate(input_slice):
if not isinstance(element, Ellipsis_):
result.append(element)
continue
ell_count += 1
if ell_count > 1:
raise IndexError("There cannot be more than one ellisis (...) in the index of the tensor, "
"but it is currently {}".format(input_slice))
for _ in range(len(shape) - len(input_slice) + 1):
result.append(Slice(None, None, None))
return tuple(result)
@constexpr
def slice2indices(input_slices, shape):
"""
Converts slice to indices.
Inputs:
slices (Union[Slice, tuple[Slice]]): Slice tuple or slice.
shape (tuple): The shape of a tensor is an integer element tuple.
Outputs:
Tensor, the shape is (n, 1).
"""
begin, end, strides = slice_expand(input_slices, shape)
np_r = []
for i, element in enumerate(shape):
s = begin[i] if (begin[i] >= 0) else (element + begin[i])
e = end[i] if (end[i] >= 0) else (element + end[i])
np_r.append(np.r_[s:e:strides[i]])
# Reference: np.ravel_multi_index((np.ix_(np.r_[1:3:1], np.r_[0:4:1], np.r_[4:0:-1])), a.shape)
np_ix = np.ix_(*np_r)
ravel = np.ravel_multi_index(np_ix, shape)
ravel = Tensor(ravel.reshape(-1, 1), dtype=mstype.int32)
return ravel
@constexpr
def check_indices(indices_size, index):
"""Checks indices whether is empty."""
if indices_size < 1:
raise IndexError("The tensor's index is unreasonable. index:{}".format(index))
return indices_size
@constexpr
def check_indices_value_size(indices_size, value_size):
"""Checks if the sizes are already matched."""
if value_size < 1:
raise ValueError("The value assigned to tensor cannot be empty.")
if value_size > 1:
if value_size != indices_size:
raise ValueError(
"The value given to tensor does not match the index size,"
" value size:{}, indics size:{}".format(value_size, indices_size))
return value_size
@constexpr
def integer_to_indices(index, shape):
"""Converts int or tuple[int] to indices."""
size = reduce(lambda x, y: x * y, shape)
range_ = np.arange(size).reshape(shape)
value = range_[index]
value = value.reshape(-1, 1)
return Tensor(value, dtype=mstype.int32)
@constexpr
def tuple_element_is_slice(indexs):
"""Judges tuple element type."""
if not indexs:
raise IndexError("Tensor's index cannot be empty.")
if isinstance(indexs, tuple):
for _, ele in enumerate(indexs):
if not isinstance(ele, Slice):
return False
return True
return False
@constexpr
def tuple_element_is_int(indexs):
"""Judges tuple element type."""
if not indexs:
raise IndexError("Tensor's index cannot be empty.")
if isinstance(indexs, tuple):
for _, ele in enumerate(indexs):
if not isinstance(ele, int):
return False
return True
return False
@constexpr
def tuple_elements_type(types):
"""Judges the type of all elements of the tuple."""
tensors_number = 0
for ele in types:
if isinstance(ele, mstype.tensor_type):
tensors_number += 1
if tensors_number == len(types):
return ALL_TENSOR
if tensors_number == 0:
return NO_TENSOR
return CONTAIN_TENSOR
@constexpr
def check_value_elements(data_dtype, types):
"""Judges the type of all elements of the tuple."""
tensors_number = 0
scalars_number = 0
for i, ele in enumerate(types):
if isinstance(ele, mstype.tensor_type):
ele_dtype = ele.element_type()
if data_dtype == ele_dtype:
tensors_number += 1
else:
raise TypeError(f"For '{TENSOR_SETITEM}', the data type of {i}th tensor '{ele_dtype}' "
f"in value tuple is not consistent with origin tensor data type '{data_dtype}'.")
elif mstype.issubclass_(ele, data_dtype):
scalars_number += 1
else:
raise TypeError(f"For '{TENSOR_SETITEM}', the {i}th element type '{ele}' in "
f"value tuple is not consistent with origin tensor data type '{data_dtype}'.")
if tensors_number == len(types):
return ALL_TENSOR
if scalars_number == len(types):
return ALL_SCALAR
raise TypeError(f"For '{TENSOR_SETITEM}', the value does not support scalar and tensor mixing, but got {types}.")
@constexpr
def get_index_tensor_dtype(dtype):
"""Check a tuple of tensor data type."""
if dtype == mstype.int32:
return INT_
if dtype == mstype.bool_:
return BOOL_
raise TypeError(f"For '{TENSOR_SETITEM}', the index tensor data type '{dtype}' is not supported.")
@constexpr
def check_index_tensors_dtype(dtypes, op_name):
"""Check a tuple of tensor data type."""
if op_name == TENSOR_GETITEM:
valid_dtypes = (mstype.int32, mstype.int64)
elif op_name == TENSOR_SETITEM:
valid_dtypes = (mstype.int32,)
else:
raise ValueError("Unsupported operation.")
for ele in dtypes:
if ele in valid_dtypes and ele == dtypes[0]:
continue
raise TypeError(f"For '{op_name}', the index tensors data type must be same, "
f"and should be one of the following: {valid_dtypes}, but got {dtypes}.")
return True
@constexpr
def check_tensor_dtype_valid(dtype, valid_dtypes):
"""Check a tensor data type."""
if dtype in valid_dtypes:
return True
raise TypeError(f"The index tensor data type must be one of "
f"the following: {valid_dtypes}, but got {dtype}.")
@constexpr
def check_tensors_dtype_same(x_dtype, y_dtype, op_name):
"""Check tensors data type same."""
if x_dtype == y_dtype:
return True
raise TypeError(f"For '{op_name}', the value data type '{y_dtype}' "
f"is not consistent with origin tensor data type {x_dtype}.")
@constexpr
def broadcast_shapes(shapes, op_name):
"""Broadcasts a tuple of tensor."""
broadcast_shape = shapes[0]
for i, shape in enumerate(shapes):
logger.debug(f"Broadcasts the {i}th tensor, the shape is {shape}.")
broadcast_shape = op_utils.get_broadcast_shape(broadcast_shape, shape, op_name)
return tuple(broadcast_shape)
@constexpr
def check_two_shapes_need_broadcast(shape_x, shape_y):
"""Check two shapes need broadcast."""
error = ValueError(f"For 'tensor setitem with tensor', the value tensor shape "
f"{shape_y} could not broadcast the required updates shape {shape_x}.")
if len(shape_y) > len(shape_x):
raise error
for i in range(-len(shape_y), 0):
if shape_y[i] > shape_x[i]:
raise error
if shape_y[i] < shape_x[i] and shape_y[i] != 1:
raise error
if shape_y == shape_x:
return False
return True
@constexpr
def compute_multiples(origin_shape, broadcast_shape):
"""Compute multiples between broadcast_shape with origin_shape."""
len_gap = len(broadcast_shape) - len(origin_shape)
return broadcast_shape[0:len_gap] + tuple(map(lambda x, y: x // y, broadcast_shape[len_gap:], origin_shape))
def tile(broadcast_shape, x):
multiples = compute_multiples(F.shape(x), broadcast_shape)
return F.tile(x, multiples)
@constexpr
def check_shapes_same(value_shapes, op_name):
"""Check if the shapes in the tuple are consistent."""
for i, shape in enumerate(value_shapes):
if shape != value_shapes[0]:
raise ValueError(f"For '{op_name}', the {i}th tensor shape in value tuple "
f"is not same as the first tensor shape.")
return True
@constexpr
def convert_scalar_to_tensor(data_shape, data_dtype, indices_shape, value, op_type):
"""Convert a scalar to a tensor."""
if op_type == SET_ITEM_BY_ONE_TENSOR:
updates_shape = indices_shape + data_shape[1:]
else:
updates_shape = indices_shape[:-1] + data_shape[indices_shape[-1]:]
if isinstance(value, mstype.dtype_to_pytype(data_dtype)):
return Tensor(np.full(updates_shape, value), dtype=data_dtype)
raise TypeError(f"For '{TENSOR_SETITEM}', the value type '{value.__class__.__name__}'"
f" is not consistent with tensor data type {data_dtype}.")
@constexpr
def convert_tuple_of_scalar_to_tensor(data_shape, data_dtype, index_shape, value, op_type):
"""Convert a tuple of scalar to a tensor."""
updates_shape = generate_updates_shape(data_shape, index_shape, op_type)
if len(value) != updates_shape[-1]:
raise ValueError(f"For '{TENSOR_SETITEM}', the number of elements : {len(value)} in the updates tuple "
f"does not meet the requirements: {updates_shape[-1]}.")
array = np.array(value, dtype=mstype.dtype_to_nptype(data_dtype))
reps = compute_multiples(updates_shape[-1:], updates_shape)
return Tensor(np.tile(array, reps))
@constexpr
def generate_updates_shape(data_shape, index_shape, op_type):
"""Generate updates shape for 'tensor setitem'."""
if op_type == SET_ITEM_BY_ONE_TENSOR:
updates_shape = index_shape + data_shape[1:]
else:
updates_shape = index_shape[:-1] + data_shape[index_shape[-1]:]
return updates_shape
@constexpr
def check_number_of_index_tensor(data_shape, tuple_len, op_name):
"""Check if the number of index tensor exceeds the dimension of the operated tensor."""
if tuple_len <= len(data_shape):
return True
raise IndexError(f"For '{op_name}', the number {tuple_len} of index tensor "
f"is greater than the dimension {len(data_shape)} of the operated tensor.")
def generate_indeices_from_tuple_of_tensor(data, tuple_index, op_name):
"""Generate an indices tensor from a tuple of tensor."""
indices = None
check_index_tensor_number = check_number_of_index_tensor(F.shape(data), len(tuple_index), op_name)
if check_index_tensor_number:
dtype_tuple = hyper_map(F.dtype, tuple_index)
check_dtypes = check_index_tensors_dtype(dtype_tuple, op_name)
if check_dtypes:
shape_tuple = hyper_map(F.shape, tuple_index)
broadcast_shape = broadcast_shapes(shape_tuple, op_name)
broadcast_tensors = hyper_map(F.partial(tile, broadcast_shape), tuple_index)
indices = pack(broadcast_tensors)
return indices
def generate_updates_from_scalar(data, indices, value, op_type):
"""Generate an updates tensor from a scalar."""
data_shape = F.shape(data)
indices_shape = F.shape(indices)
data_dtype = F.dtype(data)
return convert_scalar_to_tensor(data_shape, data_dtype, indices_shape, value, op_type)
def generate_updates_from_tuple(data, index, value, op_type):
"""Generate an updates tensor from a tuple."""
value_types = hyper_map(F.typeof, value)
data_dtype = F.dtype(data)
value_elements_type = check_value_elements(data_dtype, value_types)
if value_elements_type == ALL_TENSOR:
value_shapes = hyper_map(F.shape, value)
shapes_same = check_shapes_same(value_shapes, TENSOR_SETITEM)
if shapes_same:
value = F.pack(value)
return generate_updates_from_tensor(data, index, value, op_type)
data_shape = F.shape(data)
index_shape = F.shape(index)
return convert_tuple_of_scalar_to_tensor(data_shape, data_dtype, index_shape, value, op_type)
def generate_updates_from_tensor(data, index, value, op_type):
"""Generate an updates tensor from a tensor."""
data_shape = F.shape(data)
index_shape = F.shape(index)
value_shape = F.shape(value)
data_dtype = F.dtype(data)
value_dtype = F.dtype(value)
updates_shape = value_shape
check_dtype_same = check_tensors_dtype_same(data_dtype, value_dtype, TENSOR_SETITEM)
if check_dtype_same:
updates_shape = generate_updates_shape(data_shape, index_shape, op_type)
need_broadcast = check_two_shapes_need_broadcast(updates_shape, value_shape)
if need_broadcast:
return tile(updates_shape, value)
return value

View File

@ -15,9 +15,10 @@
"""Implementation for getitem.""" """Implementation for getitem."""
from ...composite import base from . import _utils as multi_utils
from ..import base
from ... import functional as F from ... import functional as F
from ....common import dtype as mstype
getitem = base.MultitypeFuncGraph('getitem') getitem = base.MultitypeFuncGraph('getitem')
""" """
@ -214,19 +215,45 @@ def _tensor_getitem_by_slice(data, slice_index):
return _tensor_slice(data, slice_index) return _tensor_slice(data, slice_index)
@getitem.register("Tensor", "Tensor")
def _tensor_getitem_by_tensor(data, tensor_index):
"""
Getting item of tensor by slice.
Inputs:
data (Tensor): A tensor.
tensor_index (Tensor): An index expressed by tensor.
Outputs:
Tensor, element type is same as the element type of data.
"""
check_dtypes = multi_utils.check_tensor_dtype_valid(F.dtype(tensor_index), (mstype.int32, mstype.int64))
result = None
if check_dtypes:
result = F.gather(data, tensor_index, 0)
return result
@getitem.register("Tensor", "Tuple") @getitem.register("Tensor", "Tuple")
def _tensor_getitem_by_slice_tuple(data, slice_tuple_index): def _tensor_getitem_by_tuple(data, tuple_index):
""" """
Getting item of tensor by slice tuple. Getting item of tensor by slice tuple.
Inputs: Inputs:
data (Tensor): A tensor. data (Tensor): A tensor.
slice_tuple_index (tuple): Index in tuple. tuple_index (tuple): Index in tuple.
Outputs: Outputs:
Tensor, element type is same as the element type of data. Tensor, element type is same as the element type of data.
""" """
return _tensor_slice(data, slice_tuple_index) index_types = multi_utils.hyper_map(F.typeof, tuple_index)
index_elements_type = multi_utils.tuple_elements_type(index_types)
result = None
if index_elements_type == multi_utils.NO_TENSOR:
result = _tensor_slice(data, tuple_index)
if index_elements_type == multi_utils.ALL_TENSOR:
result = _tensor_getitem_by_tuple_of_tensor(data, tuple_index)
return result
@getitem.register("Tensor", "Ellipsis") @getitem.register("Tensor", "Ellipsis")
@ -242,3 +269,10 @@ def _tensor_getitem_by_ellipsis(data, ellipsis_index):
Tensor, same as data. Tensor, same as data.
""" """
return _tensor_slice(data, ellipsis_index) return _tensor_slice(data, ellipsis_index)
def _tensor_getitem_by_tuple_of_tensor(data, tuple_index):
"""Tensor getitem by a tuple of tensor."""
indices = multi_utils.generate_indeices_from_tuple_of_tensor(data, tuple_index, multi_utils.TENSOR_GETITEM)
result = F.gather_nd(data, indices)
return result

View File

@ -18,10 +18,11 @@
from ...composite import base from ...composite import base
from ....common import dtype as mstype from ....common import dtype as mstype
from ... import functional as F from ... import functional as F
from . import _multitype_ops_util as mult_util from . import _utils as multi_utils
setitem = base.MultitypeFuncGraph('setitem') setitem = base.MultitypeFuncGraph('setitem')
@setitem.register("List", "Number", "String") @setitem.register("List", "Number", "String")
def _list_setitem_with_string(data, number_index, value): def _list_setitem_with_string(data, number_index, value):
""" """
@ -118,7 +119,7 @@ def _dict_setitem_with_number(data, key, value):
@setitem.register("Tensor", "Tensor", "Tensor") @setitem.register("Tensor", "Tensor", "Tensor")
def _tensor_setitem_by_tensor_v1(data, index, value_tensor): def _tensor_setitem_by_tensor_with_tensor(data, index, value_tensor):
""" """
Tensor assignment. Tensor assignment.
@ -137,27 +138,15 @@ def _tensor_setitem_by_tensor_v1(data, index, value_tensor):
Outputs: Outputs:
Tensor, element type and shape is same as data. Tensor, element type and shape is same as data.
""" """
result = None
index_dtype = F.dtype(index) index_dtype = F.dtype(index)
index_shape = F.shape(index) tensor_dtype = multi_utils.get_index_tensor_dtype(index_dtype)
check_result = mult_util.check_tensor_setitem_index(mstype.tensor, index_dtype) if tensor_dtype == multi_utils.INT_:
if check_result: return _tensor_setitem_by_int_tensor_with_tensor(data, index, value_tensor)
data_shape = F.shape(data) return _tensor_setitem_by_bool_tensor_with_tensor(data, index, value_tensor)
data_shape = mult_util.check_equal(data_shape, index_shape,
"The tensor(shape={}) and tensor index(shape={}) should be the same shape.")
size = F.size(value_tensor)
size = mult_util.check_equal(1, size,
"When assign value is a tensor, its size should be {}, but current size is {}.")
dtype = F.dtype(data)
u_cast = F.cast(value_tensor, dtype)
one_data = F.ones_like(data)
u = F.tensor_mul(one_data, u_cast)
result = F.select(index, u, data)
return result
@setitem.register("Tensor", "Tensor", "Number") @setitem.register("Tensor", "Tensor", "Number")
def _tensor_setitem_by_tensor_v2(data, index, value): def _tensor_setitem_by_tensor_with_number(data, index, value):
""" """
Tensor assignment. Tensor assignment.
@ -171,22 +160,128 @@ def _tensor_setitem_by_tensor_v2(data, index, value):
Inputs: Inputs:
data (Tensor): Assigned tensor. data (Tensor): Assigned tensor.
index (Tensor): Tensor of bool type. index (Tensor): Tensor of bool type.
value_tensor (Number): Assignment value. value (Number): Assignment value.
Outputs: Outputs:
Tensor, element type and shape is same as data. Tensor, element type and shape is same as data.
""" """
result = None
index_dtype = F.dtype(index) index_dtype = F.dtype(index)
index_shape = F.shape(index) tensor_dtype = multi_utils.get_index_tensor_dtype(index_dtype)
check_result = mult_util.check_tensor_setitem_index(mstype.tensor, index_dtype) if tensor_dtype == multi_utils.BOOL_:
if check_result: return _tensor_setitem_by_bool_tensor_with_scalar(data, index, value)
shape = F.shape(data) return _tensor_setitem_by_int_tensor_with_scalar(data, index, value)
shape = mult_util.check_equal(
shape, index_shape, "The tensor(shape={}) and tensor index(shape={}) should be the same shape.")
dtype = F.dtype(data) @setitem.register("Tensor", "Tuple", "Number")
u = F.fill(dtype, shape, value) def _tensor_setitem_by_tuple_with_number(data, tuple_index, value):
result = F.select(index, u, data) """
Tensor assignment.
Note:
Syntax support: A[B, C, D] = u.
Restraint condition: 1) A is a Tensor, and B, C, D are index.
2) u is a scalar.
Inputs:
data (Tensor): Assigned tensor.
index (Tuple): An index tuple.
value (Number): Assignment value.
Outputs:
Tensor, element type and shape is same as data.
"""
index_types = multi_utils.hyper_map(F.typeof, tuple_index)
index_elements_type = multi_utils.tuple_elements_type(index_types)
result = None
if index_elements_type == multi_utils.NO_TENSOR:
result = _tensor_assgin_number(data, tuple_index, value)
if index_elements_type == multi_utils.ALL_TENSOR:
indices = multi_utils.generate_indeices_from_tuple_of_tensor(data, tuple_index, multi_utils.TENSOR_SETITEM)
updates = multi_utils.generate_updates_from_scalar(data, indices, value,
multi_utils.SET_ITEM_BY_TUPLE_OF_TENSOR)
result = F.scatter_nd_update(data, indices, updates)
return result
@setitem.register("Tensor", "Tuple", "Tensor")
def _tensor_setitem_by_tuple_with_tensor(data, tuple_index, value):
"""
Tensor assignment.
Note:
Syntax support: A[B, C, D] = U.
Restraint condition: 1) A is a Tensor, and B, C, D are index Tensors.
2) U is a Tensor.
Inputs:
data (Tensor): Assigned tensor.
index (Tuple): An index tuple.
value (Tensor): Assignment tensor, should has the same data type as 'data'.
Outputs:
Tensor, element type and shape is same as data.
"""
index_types = multi_utils.hyper_map(F.typeof, tuple_index)
index_elements_type = multi_utils.tuple_elements_type(index_types)
result = None
if index_elements_type == multi_utils.NO_TENSOR:
result = _tensor_assgin_tensor(data, tuple_index, value)
if index_elements_type == multi_utils.ALL_TENSOR:
indices = multi_utils.generate_indeices_from_tuple_of_tensor(data, tuple_index, multi_utils.TENSOR_SETITEM)
updates = multi_utils.generate_updates_from_tensor(data, indices, value,
multi_utils.SET_ITEM_BY_TUPLE_OF_TENSOR)
result = F.scatter_nd_update(data, indices, updates)
return result
@setitem.register("Tensor", "Tuple", "Tuple")
def _tensor_setitem_by_tuple_with_tuple(data, tuple_index, value):
"""
Tensor assignment.
Note:
Syntax support: A[B, C, D] = U.
Restraint condition: 1) A is a Tensor, and B, C, D are index Tensors.
2) A B and C could be broadcast.
3) U is a Tensor.
Inputs:
data (Tensor): Assigned tensor.
index (Tuple): A tuple of tensor, these tensor could be broadcast.
value (Tensor): Assignment tensor, should has the same data type as 'data'.
Outputs:
Tensor, element type and shape is same as data.
"""
index_types = multi_utils.hyper_map(F.typeof, tuple_index)
index_elements_type = multi_utils.tuple_elements_type(index_types)
result = None
if index_elements_type == multi_utils.ALL_TENSOR:
indices = multi_utils.generate_indeices_from_tuple_of_tensor(data, tuple_index, multi_utils.TENSOR_SETITEM)
updates = multi_utils.generate_updates_from_tuple(data, indices, value,
multi_utils.SET_ITEM_BY_TUPLE_OF_TENSOR)
result = F.scatter_nd_update(data, indices, updates)
return result
@setitem.register("Tensor", "Tensor", "Tuple")
def _tensor_setitem_by_tensor_v2(data, index, value):
"""
Tensor assignment.
Inputs:
data (Tensor): Assigned tensor.
index (Tensor): Tensor of bool type.
value (Tuple): Assignment value.
Outputs:
Tensor, element type and shape is same as data.
"""
index_dtype = F.dtype(index)
check_dtype = multi_utils.check_tensor_dtype_valid(index_dtype, (mstype.int32, mstype.int64))
result = None
if check_dtype:
result = _tensor_setitem_by_tensor_with_tuple(data, index, value)
return result return result
@ -212,66 +307,6 @@ def _tensor_setitem_with_slice_v3(data, input_slice, value):
return _tensor_assgin_tensor(data, input_slice, value) return _tensor_assgin_tensor(data, input_slice, value)
@setitem.register("Tensor", "Tuple", "Tensor")
def _tensor_setitem_with_slice_v4(data, input_slice, value):
"""
Tensor assignment.
Note:
Syntax support: A[tuple(Slice)] = U, and A[tuple(Number)] = U
Restraint condition: A is a Tensor
Slice like "1:3, ::, :4:-1"
U is a Tensor(size=1) or Tensor(size>1)
Inputs:
data (Tensor): Assigned tensor.
input_slice (Union[tuple[Slice], tuple[Number]]): Slice expression.
value (Number): Assignment value.
Outputs:
Tensor, element type and shape is same as data.
"""
return _tensor_assgin_tensor(data, input_slice, value)
def _tensor_assgin_tensor(data, input_slice, value):
"""Assigns a tensor value to the tensor by slice."""
result = None
check_result = mult_util.check_tensor_setitem_index(input_slice)
if check_result:
data_shape = F.shape(data)
indices = mult_util.slice2indices(input_slice, data_shape)
is_tuple_int = mult_util.tuple_element_is_int(input_slice)
if is_tuple_int:
indices = mult_util.integer_to_indices(input_slice, data_shape)
result = _tensor_indices_tensor(data, data_shape, input_slice, indices, value)
return result
def _tensor_indices_tensor(data, data_shape, index, indices, value):
"""Assigns a tensor value to the tensor."""
data_size = F.size(data)
data_dtype = F.dtype(data)
indices_size = F.size(indices)
indices_size = mult_util.check_indices(indices_size, index)
update = F.fill(mstype.int32, (indices_size,), 1)
condition_1d = F.scatter_nd(indices, update, (data_size,))
condition = F.reshape(condition_1d, data_shape)
condition = F.cast(condition, mstype.bool_)
value_fill = None
value_size = F.size(value)
value_size = mult_util.check_indices_value_size(indices_size, value_size)
if value_size == 1:
value_fill = F.fill(data_dtype, (indices_size,), 1)
value = F.cast(value, data_dtype)
value_fill = F.tensor_mul(value_fill, value)
elif value_size > 1:
value_fill = F.reshape(value, (indices_size,))
value_1d = F.scatter_nd(indices, value_fill, (data_size,))
u = F.reshape(value_1d, data_shape)
return F.select(condition, u, data)
@setitem.register("Tensor", "Slice", "Number") @setitem.register("Tensor", "Slice", "Number")
def _tensor_setitem_with_slice_v1(data, input_slice, value): def _tensor_setitem_with_slice_v1(data, input_slice, value):
""" """
@ -294,63 +329,25 @@ def _tensor_setitem_with_slice_v1(data, input_slice, value):
return _tensor_assgin_number(data, input_slice, value) return _tensor_assgin_number(data, input_slice, value)
@setitem.register("Tensor", "Tuple", "Number")
def _tensor_setitem_with_slice_v2(data, input_slice, value):
"""
Tensor assignment.
Note:
Syntax support: A[tuple(Slice)] = u, and A[tuple(Number)] = u
Restraint condition: A is a Tensor.
Slice like "1:3, ::, :4:-1"
u is a scalar
Inputs:
data (Tensor): Assigned tensor.
input_slice (Union[tuple[Slice], tuple[Number]]): slice expression.
value (Number): Assignment value.
Outputs:
Tensor, element type and shape is same as data.
"""
return _tensor_assgin_number(data, input_slice, value)
def _tensor_assgin_number(data, input_slice, value): def _tensor_assgin_number(data, input_slice, value):
"""Givens a scalar assign to tensor by slice""" """Givens a scalar assign to tensor by slice"""
check_result = mult_util.check_tensor_setitem_index(input_slice) check_result = multi_utils.check_tensor_setitem_index(input_slice)
result = None result = None
if check_result: if check_result:
data_shape = F.shape(data) data_shape = F.shape(data)
indices = mult_util.slice2indices(input_slice, data_shape) indices = multi_utils.slice2indices(input_slice, data_shape)
is_tuple_int = mult_util.tuple_element_is_int(input_slice) is_tuple_int = multi_utils.tuple_element_is_int(input_slice)
if is_tuple_int: if is_tuple_int:
indices = mult_util.integer_to_indices(input_slice, data_shape) indices = multi_utils.integer_to_indices(input_slice, data_shape)
result = _tensor_indices_number(data, data_shape, input_slice, indices, value) result = _tensor_indices_number(data, data_shape, input_slice, indices, value)
return result return result
def _tensor_indices_number(data, data_shape, index, indices, value):
"""Assigns a scalar value to the tensor."""
data_size = F.size(data)
data_dtype = F.dtype(data)
indices_size = F.size(indices)
indices_size = mult_util.check_indices(indices_size, index)
update = F.fill(mstype.int32, (indices_size,), 1)
condition_1d = F.scatter_nd(indices, update, (data_size,))
condition = F.reshape(condition_1d, data_shape)
condition = F.cast(condition, mstype.bool_)
value_fill = F.fill(data_dtype, (indices_size,), value)
value_1d = F.scatter_nd(indices, value_fill, (data_size,))
u = F.reshape(value_1d, data_shape)
return F.select(condition, u, data)
@setitem.register("Tensor", "Number", "Number") @setitem.register("Tensor", "Number", "Number")
def _tensor_setitem_with_int_v1(data, index, value): def _tensor_setitem_with_int_v1(data, index, value):
"""Syntax: A[1] = 3""" """Syntax: A[1] = 3"""
data_shape = F.shape(data) data_shape = F.shape(data)
indices = mult_util.integer_to_indices(index, data_shape) indices = multi_utils.integer_to_indices(index, data_shape)
return _tensor_indices_number(data, data_shape, index, indices, value) return _tensor_indices_number(data, data_shape, index, indices, value)
@ -358,7 +355,7 @@ def _tensor_setitem_with_int_v1(data, index, value):
def _tensor_setitem_with_int_v2(data, index, value): def _tensor_setitem_with_int_v2(data, index, value):
"""Syntax: A[1] = Tensor""" """Syntax: A[1] = Tensor"""
data_shape = F.shape(data) data_shape = F.shape(data)
indices = mult_util.integer_to_indices(index, data_shape) indices = multi_utils.integer_to_indices(index, data_shape)
return _tensor_indices_tensor(data, data_shape, index, indices, value) return _tensor_indices_tensor(data, data_shape, index, indices, value)
@ -379,7 +376,7 @@ def _tensor_setitem_with_ellipsis_v2(data, index, value):
data_size = F.size(data) data_size = F.size(data)
value_shape = F.shape(value) value_shape = F.shape(value)
value_size = F.size(value) value_size = F.size(value)
check_result = mult_util.check_ellipsis_shape_size(data_shape, value_shape, data_size, value_size) check_result = multi_utils.check_ellipsis_shape_size(data_shape, value_shape, data_size, value_size)
if check_result: if check_result:
if data_size == value_size: if data_size == value_size:
result = F.reshape(value, data_shape) result = F.reshape(value, data_shape)
@ -389,3 +386,108 @@ def _tensor_setitem_with_ellipsis_v2(data, index, value):
param2 = F.cast(value, data_dtype) param2 = F.cast(value, data_dtype)
result = F.tensor_mul(param1, param2) result = F.tensor_mul(param1, param2)
return result return result
def _tensor_assgin_tensor(data, input_slice, value):
"""Assigns a tensor value to the tensor by slice."""
result = None
check_result = multi_utils.check_tensor_setitem_index(input_slice)
if check_result:
data_shape = F.shape(data)
indices = multi_utils.slice2indices(input_slice, data_shape)
is_tuple_int = multi_utils.tuple_element_is_int(input_slice)
if is_tuple_int:
indices = multi_utils.integer_to_indices(input_slice, data_shape)
result = _tensor_indices_tensor(data, data_shape, input_slice, indices, value)
return result
def _tensor_indices_tensor(data, data_shape, index, indices, value):
"""Assigns a tensor value to the tensor."""
data_size = F.size(data)
data_dtype = F.dtype(data)
indices_size = F.size(indices)
indices_size = multi_utils.check_indices(indices_size, index)
update = F.fill(mstype.int32, (indices_size,), 1)
condition_1d = F.scatter_nd(indices, update, (data_size,))
condition = F.reshape(condition_1d, data_shape)
condition = F.cast(condition, mstype.bool_)
value_fill = None
value_size = F.size(value)
value_size = multi_utils.check_indices_value_size(indices_size, value_size)
if value_size == 1:
value_fill = F.fill(data_dtype, (indices_size,), 1)
value = F.cast(value, data_dtype)
value_fill = F.tensor_mul(value_fill, value)
elif value_size > 1:
value_fill = F.reshape(value, (indices_size,))
value_1d = F.scatter_nd(indices, value_fill, (data_size,))
u = F.reshape(value_1d, data_shape)
return F.select(condition, u, data)
def _tensor_indices_number(data, data_shape, index, indices, value):
"""Assigns a scalar value to the tensor."""
data_size = F.size(data)
data_dtype = F.dtype(data)
indices_size = F.size(indices)
indices_size = multi_utils.check_indices(indices_size, index)
update = F.fill(mstype.int32, (indices_size,), 1)
condition_1d = F.scatter_nd(indices, update, (data_size,))
condition = F.reshape(condition_1d, data_shape)
condition = F.cast(condition, mstype.bool_)
value_fill = F.fill(data_dtype, (indices_size,), value)
value_1d = F.scatter_nd(indices, value_fill, (data_size,))
u = F.reshape(value_1d, data_shape)
return F.select(condition, u, data)
def _tensor_setitem_by_tensor_with_tuple(data, index, value):
"""Set a tensor item by a tensor with a tuple."""
updates = multi_utils.generate_updates_from_tuple(data, index, value,
multi_utils.SET_ITEM_BY_ONE_TENSOR)
result = F.scatter_update(data, index, updates)
return result
def _tensor_setitem_by_int_tensor_with_scalar(data, index, value):
"""Set a tensor item by a int tensor with a scalar."""
updates = multi_utils.generate_updates_from_scalar(data, index, value,
multi_utils.SET_ITEM_BY_ONE_TENSOR)
return F.scatter_update(data, index, updates)
def _tensor_setitem_by_bool_tensor_with_scalar(data, index, value):
"""Set a tensor item by a bool tensor with a scalar."""
index_shape = F.shape(index)
shape = F.shape(data)
shape = multi_utils.check_equal(
shape, index_shape, "The tensor(shape={}) and tensor index(shape={}) should be the same shape.")
dtype = F.dtype(data)
u = F.fill(dtype, shape, value)
return F.select(index, u, data)
def _tensor_setitem_by_int_tensor_with_tensor(data, index, value):
"""Set a tensor item by a int tensor with a tensor."""
updates = multi_utils.generate_updates_from_tensor(data, index, value,
multi_utils.SET_ITEM_BY_ONE_TENSOR)
return F.scatter_update(data, index, updates)
def _tensor_setitem_by_bool_tensor_with_tensor(data, index, value):
"""Set a tensor item by a bool tensor with a tensor."""
index_shape = F.shape(index)
data_shape = F.shape(data)
data_shape = multi_utils.check_equal(data_shape, index_shape,
"The tensor(shape={}) and tensor index(shape={}) should be the same shape.")
size = F.size(value)
size = multi_utils.check_equal(1, size,
"When assign value is a tensor, its size should be {}, but current size is {}.")
dtype = F.dtype(data)
u_cast = F.cast(value, dtype)
one_data = F.ones_like(data)
u = F.tensor_mul(one_data, u_cast)
result = F.select(index, u, data)
return result

View File

@ -31,6 +31,7 @@ dtype = P.DType()
issubclass_ = P.IsSubClass() issubclass_ = P.IsSubClass()
isinstance_ = P.IsInstance() isinstance_ = P.IsInstance()
fill = P.Fill() fill = P.Fill()
tile = P.Tile()
select = P.Select() select = P.Select()
size = P.Size() size = P.Size()
ones_like = P.OnesLike() ones_like = P.OnesLike()
@ -70,6 +71,12 @@ scalar_cast = P.ScalarCast()
print_ = P.Print() print_ = P.Print()
expand_dims = P.ExpandDims() expand_dims = P.ExpandDims()
scatter_nd = P.ScatterNd() scatter_nd = P.ScatterNd()
gather = P.GatherV2()
gather_nd = P.GatherNd()
scatter_update = P.ScatterUpdate()
scatter_nd_update = P.ScatterNdUpdate()
pack = P.Pack()
tuple_setitem = Primitive('tuple_setitem') tuple_setitem = Primitive('tuple_setitem')
tuple_getitem = Primitive('tuple_getitem') tuple_getitem = Primitive('tuple_getitem')

View File

@ -24,7 +24,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
Fill, GatherNd, GatherV2, InvertPermutation, Fill, GatherNd, GatherV2, InvertPermutation,
IsInstance, IsSubClass, ArgMaxWithValue, OnesLike, ZerosLike, IsInstance, IsSubClass, ArgMaxWithValue, OnesLike, ZerosLike,
Rank, Reshape, ResizeNearestNeighbor, ArgMinWithValue, Rank, Reshape, ResizeNearestNeighbor, ArgMinWithValue,
SameTypeShape, ScatterMax, SameTypeShape, ScatterMax, ScatterUpdate,
ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select, ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select,
Shape, Size, Slice, Split, Shape, Size, Slice, Split,
Squeeze, StridedSlice, Tile, Squeeze, StridedSlice, Tile,
@ -193,6 +193,7 @@ __all__ = [
'Pad', 'Pad',
'MirrorPad', 'MirrorPad',
'GatherNd', 'GatherNd',
'ScatterUpdate',
'ScatterNdUpdate', 'ScatterNdUpdate',
'Floor', 'Floor',
'NMSWithMask', 'NMSWithMask',

View File

@ -19,7 +19,7 @@ from ..._c_expression import signature_rw as sig_rw
from ..._c_expression import signature_kind as sig_kind from ..._c_expression import signature_kind as sig_kind
from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register
from ..._checkparam import Validator as validator, Rel from ..._checkparam import Validator as validator, Rel
from .._utils import _get_concat_offset from .._utils import get_concat_offset
from ...common import dtype as mstype from ...common import dtype as mstype
@ -136,7 +136,7 @@ class ConcatOffset(PrimitiveWithInfer):
axis = self.axis axis = self.axis
x_shp = input_x['shape'] x_shp = input_x['shape']
x_type = input_x['dtype'] x_type = input_x['dtype']
offset, _, axis = _get_concat_offset(x_shp, x_type, axis, self.name) offset, _, axis = get_concat_offset(x_shp, x_type, axis, self.name)
self.add_prim_attr('T', x_type[0].element_type()) self.add_prim_attr('T', x_type[0].element_type())
offset_values = [] offset_values = []
for i in range(len(x_shp)): for i in range(len(x_shp)):

View File

@ -24,16 +24,15 @@ import itertools
import numbers import numbers
import numpy as np import numpy as np
from ..._c_expression import signature_rw as sig_rw
from ..._c_expression import signature_kind as sig_kind
from ..._checkparam import Validator as validator from ..._checkparam import Validator as validator
from ..._checkparam import Rel from ..._checkparam import Rel
from ...common import dtype as mstype from ...common import dtype as mstype
from ...common.tensor import Tensor from ...common.tensor import Tensor
from ..operations.math_ops import _infer_shape_reduce from ..operations.math_ops import _infer_shape_reduce
from .._utils import _get_concat_offset from .._utils import get_concat_offset
from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register
def _check_infer_attr_reduce(axis, keep_dims, prim_name): def _check_infer_attr_reduce(axis, keep_dims, prim_name):
validator.check_value_type('keep_dims', keep_dims, [bool], prim_name) validator.check_value_type('keep_dims', keep_dims, [bool], prim_name)
validator.check_value_type('axis', axis, [int, tuple], prim_name) validator.check_value_type('axis', axis, [int, tuple], prim_name)
@ -931,7 +930,7 @@ class InvertPermutation(PrimitiveWithInfer):
z = [x_value[i] for i in range(len(x_value))] z = [x_value[i] for i in range(len(x_value))]
z.sort() z.sort()
y = [None]*len(x_value) y = [None] * len(x_value)
for i, value in enumerate(x_value): for i, value in enumerate(x_value):
validator.check_value_type("input[%d]" % i, value, [int], self.name) validator.check_value_type("input[%d]" % i, value, [int], self.name)
validator.check(f'value', z[i], f'index', i, Rel.EQ, self.name) validator.check(f'value', z[i], f'index', i, Rel.EQ, self.name)
@ -1111,6 +1110,7 @@ class ArgMinWithValue(PrimitiveWithInfer):
>>> input_x = Tensor(np.random.rand(5)) >>> input_x = Tensor(np.random.rand(5))
>>> index, output = P.ArgMinWithValue()(input_x) >>> index, output = P.ArgMinWithValue()(input_x)
""" """
@prim_attr_register @prim_attr_register
def __init__(self, axis=0, keep_dims=False): def __init__(self, axis=0, keep_dims=False):
"""init ArgMinWithValue""" """init ArgMinWithValue"""
@ -1352,7 +1352,7 @@ class Concat(PrimitiveWithInfer):
axis = self.axis axis = self.axis
x_shp = input_x['shape'] x_shp = input_x['shape']
x_type = input_x['dtype'] x_type = input_x['dtype']
_, all_shp, _ = _get_concat_offset(x_shp, x_type, axis, self.name) _, all_shp, _ = get_concat_offset(x_shp, x_type, axis, self.name)
self.add_prim_attr('T', x_type[0].element_type()) self.add_prim_attr('T', x_type[0].element_type())
self.add_prim_attr('inputNums', len(x_shp)) self.add_prim_attr('inputNums', len(x_shp))
ret_shp = x_shp[0].copy() ret_shp = x_shp[0].copy()
@ -1376,15 +1376,13 @@ def _get_pack_shape(x_shape, x_type, axis, prim_name):
if axis < 0: if axis < 0:
axis = axis + rank_base + 1 axis = axis + rank_base + 1
for i in range(1, N): for i in range(1, N):
v = x_shape[i]
validator.check('len of x_shape[%d]' % i, len(v), 'len of rank_base', rank_base, Rel.EQ, prim_name)
validator.check('x_type[%d]' % i, x_type[i], 'base', x_type[0], Rel.EQ, prim_name, TypeError) validator.check('x_type[%d]' % i, x_type[i], 'base', x_type[0], Rel.EQ, prim_name, TypeError)
for j in range(rank_base): if x_shape[i] != x_shape[0]:
if v[j] != x_shape[0][j]: raise ValueError(f"For \'{prim_name}\' element {i} shape in input can not pack with first element")
raise ValueError(f"For \'{prim_name}\' element {i} shape in input can not pack with first element")
out_shape.insert(axis, N) out_shape.insert(axis, N)
return out_shape return out_shape
class Pack(PrimitiveWithInfer): class Pack(PrimitiveWithInfer):
r""" r"""
Packs a list of tensors in specified axis. Packs a list of tensors in specified axis.
@ -1831,7 +1829,7 @@ class DiagPart(PrimitiveWithInfer):
return x_type return x_type
def infer_shape(self, x_shape): def infer_shape(self, x_shape):
if len(x_shape)%2 != 0 or \ if len(x_shape) % 2 != 0 or \
not x_shape: not x_shape:
raise ValueError(f"For \'{self.name}\' input rank must be non-zero and even, but got rank {len(x_shape)}, " raise ValueError(f"For \'{self.name}\' input rank must be non-zero and even, but got rank {len(x_shape)}, "
f"with shapes {x_shape}") f"with shapes {x_shape}")
@ -2004,6 +2002,49 @@ class GatherNd(PrimitiveWithInfer):
return x_dtype return x_dtype
class ScatterUpdate(PrimitiveWithInfer):
"""
Update tensor value by using input indices and value.
Using given values to update tensor value, along with the input indices.
Args:
use_locking (bool): Whether protect the assignment by a lock. Default: True.
Inputs:
- **input_x** (Parameter) - The target tensor, with data type of Parameter.
- **indices** (Tensor) - The index of input tensor.
- **update** (Tensor) - The tensor to update the input tensor, has the same type as input,
and update.shape = indices.shape + input_x.shape[1:].
Outputs:
Tensor, has the same shape and type as `input_x`.
Examples:
>>> input_x = mindspore.Parameter(Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mindspore.float32))
>>> indices = Tensor(np.array([[0, 0], [1, 1]]), mindspore.int32)
>>> update = Tensor(np.array([1.0, 2.2]), mindspore.float32)
>>> op = P.ScatterNdUpdate()
>>> output = op(input_x, indices, update)
"""
@prim_attr_register
def __init__(self, use_locking=True):
"""Init ScatterNdUpdate"""
self.init_prim_io_names(inputs=['x', 'indices', 'value'], outputs=['y'])
def infer_shape(self, x_shape, indices_shape, value_shape):
if indices_shape + x_shape[1:] != value_shape:
raise ValueError('Input value are not match with input indices.')
return x_shape
def infer_dtype(self, x_dtype, indices_dtype, value_dtype):
validator.check_tensor_type_same({'indices': indices_dtype}, mstype.int_type, self.name)
args = {"x": x_dtype, "value": value_dtype}
validator.check_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name)
return x_dtype
class ScatterNdUpdate(PrimitiveWithInfer): class ScatterNdUpdate(PrimitiveWithInfer):
""" """
Update tensor value by using input indices and value. Update tensor value by using input indices and value.
@ -2028,11 +2069,6 @@ class ScatterNdUpdate(PrimitiveWithInfer):
>>> op = P.ScatterNdUpdate() >>> op = P.ScatterNdUpdate()
>>> output = op(input_x, indices, update) >>> output = op(input_x, indices, update)
""" """
__mindspore_signature__ = (
('input_x', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD),
('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD),
('value', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD)
)
@prim_attr_register @prim_attr_register
def __init__(self, use_locking=True): def __init__(self, use_locking=True):
@ -2142,10 +2178,10 @@ class SpaceToDepth(PrimitiveWithInfer):
validator.check('x dimension', len(x_shape), '', 4, Rel.EQ) validator.check('x dimension', len(x_shape), '', 4, Rel.EQ)
out_shape = copy.deepcopy(x_shape) out_shape = copy.deepcopy(x_shape)
for i in range(2): for i in range(2):
if out_shape[i+2] % self.block_size != 0: if out_shape[i + 2] % self.block_size != 0:
raise ValueError(f'For \'{self.name}\' input shape[{i+2}] {out_shape[i+2]} should be ' raise ValueError(f'For \'{self.name}\' input shape[{i + 2}] {out_shape[i + 2]} should be '
f'fully divided by block_size {self.block_size}') f'fully divided by block_size {self.block_size}')
out_shape[i+2] //= self.block_size out_shape[i + 2] //= self.block_size
out_shape[1] *= self.block_size * self.block_size out_shape[1] *= self.block_size * self.block_size
return out_shape return out_shape
@ -2199,9 +2235,10 @@ class DepthToSpace(PrimitiveWithInfer):
validator.check('x dimension', len(x_shape), '', 4, Rel.EQ) validator.check('x dimension', len(x_shape), '', 4, Rel.EQ)
out_shape = copy.deepcopy(x_shape) out_shape = copy.deepcopy(x_shape)
for i in range(2): for i in range(2):
out_shape[i+2] *= self.block_size out_shape[i + 2] *= self.block_size
validator.check_integer('x_shape[1] % (block_size*block_size)', x_shape[1] % (self.block_size*self.block_size), validator.check_integer('x_shape[1] % (block_size*block_size)',
x_shape[1] % (self.block_size * self.block_size),
0, Rel.EQ, self.name) 0, Rel.EQ, self.name)
out_shape[1] //= self.block_size * self.block_size out_shape[1] //= self.block_size * self.block_size
return out_shape return out_shape
@ -2251,6 +2288,7 @@ class SpaceToBatch(PrimitiveWithInfer):
[[[[1.]]], [[[2.]]], [[[3.]]], [[[4.]]]] [[[[1.]]], [[[2.]]], [[[3.]]], [[[4.]]]]
""" """
@prim_attr_register @prim_attr_register
def __init__(self, block_size, paddings): def __init__(self, block_size, paddings):
"""Init SpaceToBatch""" """Init SpaceToBatch"""
@ -2271,12 +2309,12 @@ class SpaceToBatch(PrimitiveWithInfer):
validator.check_integer('rank of input_x', len(x_shape), 4, Rel.EQ, self.name) validator.check_integer('rank of input_x', len(x_shape), 4, Rel.EQ, self.name)
out_shape = copy.deepcopy(x_shape) out_shape = copy.deepcopy(x_shape)
for i in range(2): for i in range(2):
padded = out_shape[i+2] + self.paddings[i][0] + \ padded = out_shape[i + 2] + self.paddings[i][0] + \
self.paddings[i][1] self.paddings[i][1]
if padded % self.block_size != 0: if padded % self.block_size != 0:
raise ValueError(f'For \'{self.name}\' padded[{i}] {padded} should be divisible by ' raise ValueError(f'For \'{self.name}\' padded[{i}] {padded} should be divisible by '
f'block_size {self.block_size}') f'block_size {self.block_size}')
out_shape[i+2] = padded // self.block_size out_shape[i + 2] = padded // self.block_size
out_shape[0] *= self.block_size * self.block_size out_shape[0] *= self.block_size * self.block_size
return out_shape return out_shape
@ -2319,6 +2357,7 @@ class BatchToSpace(PrimitiveWithInfer):
[[[[1., 2.], [3., 4.]]]] [[[[1., 2.], [3., 4.]]]]
""" """
@prim_attr_register @prim_attr_register
def __init__(self, block_size, crops): def __init__(self, block_size, crops):
"""Init BatchToSpace""" """Init BatchToSpace"""
@ -2339,10 +2378,10 @@ class BatchToSpace(PrimitiveWithInfer):
validator.check('rank of input_x', len(x_shape), '', 4) validator.check('rank of input_x', len(x_shape), '', 4)
out_shape = copy.deepcopy(x_shape) out_shape = copy.deepcopy(x_shape)
for i in range(2): for i in range(2):
x_block_prod = out_shape[i+2] * self.block_size x_block_prod = out_shape[i + 2] * self.block_size
crops_sum = self.crops[i][0] + self.crops[i][1] crops_sum = self.crops[i][0] + self.crops[i][1]
validator.check("x block shape prod", x_block_prod, 'crops sum', crops_sum, Rel.GT, self.name) validator.check("x block shape prod", x_block_prod, 'crops sum', crops_sum, Rel.GT, self.name)
out_shape[i+2] = x_block_prod - crops_sum out_shape[i + 2] = x_block_prod - crops_sum
block_size_prod = self.block_size * self.block_size block_size_prod = self.block_size * self.block_size
if out_shape[0] % block_size_prod != 0: if out_shape[0] % block_size_prod != 0:
raise ValueError(f'For \'{self.name}\' input_x dimension 0 {out_shape[0]} should be divisible by ' raise ValueError(f'For \'{self.name}\' input_x dimension 0 {out_shape[0]} should be divisible by '

View File

@ -24,7 +24,7 @@ from ..._checkparam import Validator as validator
from ..._checkparam import Rel from ..._checkparam import Rel
from ...common import dtype as mstype from ...common import dtype as mstype
from ...common.tensor import Tensor from ...common.tensor import Tensor
from .._utils import _get_broadcast_shape from .._utils import get_broadcast_shape
from ..primitive import PrimitiveWithInfer, prim_attr_register, _run_op from ..primitive import PrimitiveWithInfer, prim_attr_register, _run_op
@ -75,7 +75,7 @@ class _BinaryOp(PrimitiveWithInfer):
self.init_prim_io_names(inputs=['x', 'y'], outputs=['output']) self.init_prim_io_names(inputs=['x', 'y'], outputs=['output'])
def infer_shape(self, x_shape, y_shape): def infer_shape(self, x_shape, y_shape):
return _get_broadcast_shape(x_shape, y_shape, self.name) return get_broadcast_shape(x_shape, y_shape, self.name)
class _MathBinaryOp(_BinaryOp): class _MathBinaryOp(_BinaryOp):

View File

@ -27,9 +27,15 @@ class IdentityEC(IExectorComponent):
def __call__(self): def __call__(self):
result_id = self.function[keyword.id] + '-' + self.inputs[keyword.id] result_id = self.function[keyword.id] + '-' + self.inputs[keyword.id]
group = self.function[keyword.group] + '-' + self.inputs[keyword.group] group = self.function[keyword.group] + '-' + self.inputs[keyword.group]
return { ret = {
keyword.id: result_id, keyword.id: result_id,
keyword.group: group, keyword.group: group,
keyword.desc_inputs: self.inputs[keyword.desc_inputs], keyword.desc_inputs: self.inputs[keyword.desc_inputs],
keyword.result: self.function[keyword.block](*self.inputs[keyword.desc_inputs]) keyword.result: self.function[keyword.block](*self.inputs[keyword.desc_inputs])
} }
print("buxue------------------------------------------------")
print("inputs")
print(ret[keyword.desc_inputs])
print("outputs")
print(ret[keyword.result])
return ret

View File

@ -1297,7 +1297,7 @@ raise_set = [
('ScatterNdUpdate', { ('ScatterNdUpdate', {
'block': (P.ScatterNdUpdate(), {'exception': TypeError}), 'block': (P.ScatterNdUpdate(), {'exception': TypeError}),
'desc_inputs': (Tensor(np.ones((2, 3), np.float32)), 'desc_inputs': (Tensor(np.ones((2, 3), np.float32)),
Tensor(np.ones((2, 2), np.int32)), Tensor(np.ones((2, 2), np.float32)),
Tensor(np.ones((2,), np.float32))), Tensor(np.ones((2,), np.float32))),
'desc_bprop': [[2, 3]]}), 'desc_bprop': [[2, 3]]}),
('Pack', { ('Pack', {

View File

@ -16,13 +16,14 @@
import numpy as np import numpy as np
import pytest import pytest
from mindspore import Tensor from mindspore import Tensor, Parameter
from mindspore import context from mindspore import context
from mindspore import dtype as mstype from mindspore import dtype as mstype
from mindspore.nn import Cell from mindspore.nn import Cell
from ....mindspore_test_framework.mindspore_test import mindspore_test from ....mindspore_test_framework.mindspore_test import mindspore_test
from ....mindspore_test_framework.pipeline.forward.compile_forward \ from ....mindspore_test_framework.pipeline.forward.compile_forward \
import pipeline_for_compile_forward_ge_graph_for_case_by_case_config import pipeline_for_compile_forward_ge_graph_for_case_by_case_config, \
pipeline_for_compile_forward_ge_graph_for_case_by_case_config_exception
class NetWorkSlicePositive(Cell): class NetWorkSlicePositive(Cell):
@ -145,6 +146,160 @@ class TensorAssignWithSlice(Cell):
return z return z
class TensorIndexByOneTensor(Cell):
def __init__(self):
super(TensorIndexByOneTensor, self).__init__()
self.const = Tensor(np.ones((5, 4, 7, 8)), mstype.int32)
def construct(self, x, index):
ret = x[index] + self.const
return ret
class TensorIndexByTwoTensors(Cell):
def __init__(self):
super(TensorIndexByTwoTensors, self).__init__()
self.const = Tensor(np.ones((3, 4, 5, 8)), mstype.int32)
def construct(self, x, index_0, index_1):
ret = x[index_0, index_1] + self.const
return ret
class TensorIndexByThreeTensors(Cell):
def __init__(self):
super(TensorIndexByThreeTensors, self).__init__()
self.const = Tensor(np.ones((5, 3, 4, 5)), mstype.int32)
def construct(self, x, index_0, index_1, index_2):
ret = x[index_0, index_1, index_2] + self.const
return ret
class TensorSetItemByOneTensorWithNumber(Cell):
def __init__(self, value):
super(TensorSetItemByOneTensorWithNumber, self).__init__()
self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x")
self.value = value
def construct(self, index):
self.param[index] = self.value
ret = self.param + self.const
return ret
class TensorSetItemByOneTensorWithTensor(Cell):
def __init__(self):
super(TensorSetItemByOneTensorWithTensor, self).__init__()
self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x")
def construct(self, index, value):
self.param[index] = value
ret = self.param + self.const
return ret
class TensorSetItemByOneTensorWithTupleOfNumber(Cell):
def __init__(self, value):
super(TensorSetItemByOneTensorWithTupleOfNumber, self).__init__()
self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x")
self.value = value
def construct(self, index):
self.param[index] = self.value
ret = self.param + self.const
return ret
class TensorSetItemByOneTensorWithTupleOfTensor(Cell):
def __init__(self):
super(TensorSetItemByOneTensorWithTupleOfTensor, self).__init__()
self.const = Tensor(np.ones((6, 3, 8)), mstype.float32)
self.param = Parameter(Tensor(np.arange(6*3*8).reshape((6, 3, 8)), mstype.float32), name="x")
def construct(self, index, value_0, value_1, value_2):
self.param[index] = (value_0, value_1, value_2)
ret = self.param + self.const
return ret
class TensorSetItemByTensorsWithNumber(Cell):
def __init__(self, value):
super(TensorSetItemByTensorsWithNumber, self).__init__()
self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x")
self.value = value
def construct(self, index_0, index_1, index_2):
self.param[index_0, index_1, index_2] = self.value
ret = self.param + self.const
return ret
class TensorSetItemByTensorsWithTensor(Cell):
def __init__(self):
super(TensorSetItemByTensorsWithTensor, self).__init__()
self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x")
def construct(self, index_0, index_1, index_2, value):
self.param[index_0, index_1, index_2] = value
ret = self.param + self.const
return ret
class TensorSetItemByTensorsWithTensorNumberError(Cell):
def __init__(self):
super(TensorSetItemByTensorsWithTensorNumberError, self).__init__()
self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x")
def construct(self, index_0, index_1, index_2, index_3, value):
self.param[index_0, index_1, index_2, index_3] = value
ret = self.param + self.const
return ret
class TensorSetItemByTensorsWithTupleOfNumber(Cell):
def __init__(self, value):
super(TensorSetItemByTensorsWithTupleOfNumber, self).__init__()
self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x")
self.value = value
def construct(self, index_0, index_1, index_2):
self.param[index_0, index_1, index_2] = self.value
ret = self.param + self.const
return ret
class TensorSetItemByTensorsWithTupleOfTensor(Cell):
def __init__(self):
super(TensorSetItemByTensorsWithTupleOfTensor, self).__init__()
self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x")
def construct(self, index_0, index_1, index_2, value_0, value_1, value_2):
self.param[index_0, index_1, index_2] = (value_0, value_1, value_2)
ret = self.param + self.const
return ret
class TensorSetItemByTensorsWithTupleOfTensorNumberError(Cell):
def __init__(self):
super(TensorSetItemByTensorsWithTupleOfTensorNumberError, self).__init__()
self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x")
def construct(self, index_0, index_1, index_2, value_0, value_1):
self.param[index_0, index_1, index_2] = (value_0, value_1)
ret = self.param + self.const
return ret
def test_tensor_assign(): def test_tensor_assign():
context.set_context(mode=context.GRAPH_MODE, save_graphs=True) context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
net = TensorAssignWithSlice() net = TensorAssignWithSlice()
@ -441,15 +596,206 @@ test_cases = [
'block': NetWorkSliceEllipsis(), 'block': NetWorkSliceEllipsis(),
'desc_inputs': [Tensor(np.ones([6, 7, 8, 9], np.int32))], 'desc_inputs': [Tensor(np.ones([6, 7, 8, 9], np.int32))],
}), }),
('TensorIndexByOneTensor', {
'block': TensorIndexByOneTensor(),
'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32),
Tensor(np.random.randint(6, size=(5, 4)), mstype.int32)],
}),
('TensorIndexByTwoTensors', {
'block': TensorIndexByTwoTensors(),
'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32),
Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(7, size=(4, 5)), mstype.int32)],
}),
('TensorIndexByThreeTensors', {
'block': TensorIndexByThreeTensors(),
'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32),
Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)],
}),
('TensorSetItemByOneTensorWithNumber', {
'block': TensorSetItemByOneTensorWithNumber(value=0.0),
'desc_inputs': [Tensor(np.random.randint(4, size=(5, 4)), mstype.int32)],
}),
('TensorSetItemByOneTensorWithTensor', {
'block': TensorSetItemByOneTensorWithTensor(),
'desc_inputs': [Tensor(np.random.randint(3, size=(5, 4)), mstype.int32),
Tensor(np.zeros((4, 7, 8)), mstype.float32)],
}),
('TensorSetItemByOneTensorWithTupleOfNumber', {
'block': TensorSetItemByOneTensorWithTupleOfNumber(value=(0.0, 1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7)),
'desc_inputs': [Tensor(np.random.randint(5, size=(5, 4)), mstype.int32)],
}),
('TensorSetItemByOneTensorWithTupleOfTensor', {
'block': TensorSetItemByOneTensorWithTupleOfTensor(),
'desc_inputs': [Tensor(np.random.randint(6, size=(5, 4)), mstype.int32),
Tensor(np.zeros((8,), np.float32)),
Tensor(np.ones((8,), np.float32)),
Tensor(np.ones((8,), np.float32) * 2)],
}),
('TensorSetItemByTensorsWithNumber', {
'block': TensorSetItemByTensorsWithNumber(value=0.0),
'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)],
}),
('TensorSetItemByTensorsWithTensor', {
'block': TensorSetItemByTensorsWithTensor(),
'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32),
Tensor(np.zeros((4, 5)), mstype.float32)],
}),
('TensorSetItemByTensorsWithTupleOfNumber', {
'block': TensorSetItemByTensorsWithTupleOfNumber(value=(0.0, 1.1, 2.2, 3.3, 4.4)),
'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)],
}),
('TensorSetItemByTensorsWithTupleOfTensor', {
'block': TensorSetItemByTensorsWithTupleOfTensor(),
'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32),
Tensor(np.zeros((4, 5)), mstype.float32),
Tensor(np.ones((4, 5)), mstype.float32),
Tensor(np.ones((4, 5)) * 2, mstype.float32)],
})
]
raise_error_set = [
('TensorIndexByOneTensorDtypeError', {
'block': (TensorIndexByOneTensor(), {'exception': TypeError}),
'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32),
Tensor(np.random.randint(6, size=(5, 4)), mstype.int8)],
}),
('TensorIndexByTwoTensorsShapeError', {
'block': (TensorIndexByTwoTensors(), {'exception': ValueError}),
'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32),
Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(7, size=(2, 3, 5)), mstype.int32)],
}),
('TensorIndexByTwoTensorsDtypeError', {
'block': (TensorIndexByTwoTensors(), {'exception': TypeError}),
'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32),
Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(7, size=(4, 5)), mstype.float32)],
}),
('TensorIndexByThreeTensorsShapeError', {
'block': (TensorIndexByThreeTensors(), {'exception': ValueError}),
'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32),
Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(7, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(8, size=(5, 2, 4, 5)), mstype.int32)],
}),
('TensorIndexByThreeTensorsDtypeError', {
'block': (TensorIndexByThreeTensors(), {'exception': TypeError}),
'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32),
Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(7, size=(3, 4, 5)), mstype.int64),
Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)],
}),
('TensorSetItemByOneTensorWithNumberTypeError', {
'block': (TensorSetItemByOneTensorWithNumber(value=0), {'exception': TypeError}),
'desc_inputs': [Tensor(np.random.randint(4, size=(5, 4)), mstype.int32)],
}),
('TensorSetItemByOneTensorWithTensorShapeError', {
'block': (TensorSetItemByOneTensorWithTensor(), {'exception': ValueError}),
'desc_inputs': [Tensor(np.random.randint(3, size=(5, 4)), mstype.int32),
Tensor(np.zeros((6, 7, 8)), mstype.float32)],
}),
('TensorSetItemByOneTensorWithTensorDtypeError', {
'block': (TensorSetItemByOneTensorWithTensor(), {'exception': TypeError}),
'desc_inputs': [Tensor(np.random.randint(3, size=(5, 4)), mstype.int32),
Tensor(np.zeros((6, 7, 8)), mstype.int32)],
}),
('TensorSetItemByOneTensorWithTupleOfNumberTypeError', {
'block': (TensorSetItemByOneTensorWithTupleOfNumber(value=(0, 1, 2, 3, 4, 5, 6, 7)), {'exception': TypeError}),
'desc_inputs': [Tensor(np.random.randint(5, size=(5, 4)), mstype.int32)],
}),
('TensorSetItemByOneTensorWithTupleOfNumberNumberError', {
'block': (TensorSetItemByOneTensorWithTupleOfNumber(value=(0.0, 1.1, 2.2)), {'exception': ValueError}),
'desc_inputs': [Tensor(np.random.randint(5, size=(5, 4)), mstype.int32)],
}),
('TensorSetItemByOneTensorWithTupleOfTensorDtyeError', {
'block': (TensorSetItemByOneTensorWithTupleOfTensor(), {'exception': TypeError}),
'desc_inputs': [Tensor(np.random.randint(6, size=(5, 4)), mstype.int32),
Tensor(np.zeros((8,), np.int32)),
Tensor(np.ones((8,), np.int32)),
Tensor(np.ones((8,), np.float32) * 2)],
}),
('TensorSetItemByTensorsWithNumberTypeError', {
'block': (TensorSetItemByTensorsWithNumber(value=0), {'exception': TypeError}),
'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)],
}),
('TensorSetItemByTensorsWithTensorShapeError', {
'block': (TensorSetItemByTensorsWithTensor(), {'exception': ValueError}),
'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32),
Tensor(np.zeros((2, 5)), mstype.float32)],
}),
('TensorSetItemByTensorsWithTensorTypeError', {
'block': (TensorSetItemByTensorsWithTensor(), {'exception': TypeError}),
'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32),
Tensor(np.zeros((4, 5)), mstype.int32)],
}),
('TensorSetItemByTensorsWithTensorNumberError', {
'block': (TensorSetItemByTensorsWithTensorNumberError(), {'exception': IndexError}),
'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32),
Tensor(np.random.randint(8, size=(1, 3, 4, 5)), mstype.int32),
Tensor(np.zeros((2, 5)), mstype.float32)],
}),
('TensorSetItemByTensorsWithTupleOfNumberTypeError', {
'block': (TensorSetItemByTensorsWithTupleOfNumber(value=(0, 1, 2, 3, 4)), {'exception': TypeError}),
'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)],
}),
('TensorSetItemByTensorsWithTupleOfNumberNumberError', {
'block': (TensorSetItemByTensorsWithTupleOfNumber(value=(0.0, 1.0, 2.0, 3.0)), {'exception': ValueError}),
'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)],
}),
('TensorSetItemByTensorsWithTupleOfTensorNumberError', {
'block': (TensorSetItemByTensorsWithTupleOfTensorNumberError(), {'exception': ValueError}),
'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32),
Tensor(np.zeros((4, 5)), mstype.float32),
Tensor(np.ones((4, 5)), mstype.float32)],
}),
('TensorSetItemByTensorsWithTupleOfTensorTypeError', {
'block': (TensorSetItemByTensorsWithTupleOfTensor(), {'exception': TypeError}),
'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32),
Tensor(np.zeros((4, 5)), mstype.float32),
Tensor(np.ones((4, 5)), mstype.int32),
Tensor(np.ones((4, 5)) * 2, mstype.int32)],
})
] ]
@mindspore_test(pipeline_for_compile_forward_ge_graph_for_case_by_case_config) @mindspore_test(pipeline_for_compile_forward_ge_graph_for_case_by_case_config)
def test_compile(): def test_exec():
context.set_context(mode=context.GRAPH_MODE) context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
return test_cases return test_cases
@mindspore_test(pipeline_for_compile_forward_ge_graph_for_case_by_case_config_exception)
def test_check_exception():
return raise_error_set
def test_tensor_slice_reduce_out_of_bounds_neg(): def test_tensor_slice_reduce_out_of_bounds_neg():
class NetWork(Cell): class NetWork(Cell):
def __init__(self): def __init__(self):

View File

@ -26,7 +26,7 @@ from mindspore.ops import functional as F
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.ops._grad.grad_base import bprop_getters from mindspore.ops._grad.grad_base import bprop_getters
from mindspore.ops._grad.grad_math_ops import binop_grad_common from mindspore.ops._grad.grad_math_ops import binop_grad_common
from mindspore.ops._utils import _get_broadcast_shape from mindspore.ops._utils import get_broadcast_shape
from mindspore.ops.primitive import PrimitiveWithInfer, prim_attr_register from mindspore.ops.primitive import PrimitiveWithInfer, prim_attr_register
from mindspore.train.loss_scale_manager import DynamicLossScaleManager from mindspore.train.loss_scale_manager import DynamicLossScaleManager
@ -54,7 +54,7 @@ class MockSub(PrimitiveWithInfer):
self.init_prim_io_names(inputs=['x', 'y'], outputs=['output']) self.init_prim_io_names(inputs=['x', 'y'], outputs=['output'])
def infer_shape(self, x_shape, y_shape): def infer_shape(self, x_shape, y_shape):
return _get_broadcast_shape(x_shape, y_shape) return get_broadcast_shape(x_shape, y_shape)
def infer_dtype(self, x_dtype, y_dtype): def infer_dtype(self, x_dtype, y_dtype):
return x_dtype return x_dtype