forked from mindspore-Ecosystem/mindspore
!1133 support tensor get value by tensor index
Merge pull request !1133 from zhangbuxue/support_tensor_get_value_by_tensor_index
This commit is contained in:
commit
274f6f014f
|
@ -1172,6 +1172,12 @@ int GenerateStridedSliceParametersFromNumber(const AbstractScalarPtr &scalar, co
|
|||
return 1;
|
||||
}
|
||||
|
||||
FuncGraphPtr ExpandADim(const FuncGraphPtr &ret_graph, const AnfNodePtr &tensor_node) {
|
||||
auto PrimExpandDims = GetPythonOps("expand_dims", "mindspore.ops.functional");
|
||||
ret_graph->set_output(NewCNode({NewValueNode(PrimExpandDims), tensor_node, NewValueNode(0)}, ret_graph));
|
||||
return ret_graph;
|
||||
}
|
||||
|
||||
FuncGraphPtr TensorSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
|
||||
// slice a tensor
|
||||
// args: tensor, slice or slice tuple
|
||||
|
@ -1229,12 +1235,6 @@ FuncGraphPtr TensorSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec
|
|||
return ret_graph;
|
||||
}
|
||||
|
||||
FuncGraphPtr TensorSlice::ExpandADim(const FuncGraphPtr &ret_graph, const AnfNodePtr &tensor_node) const {
|
||||
auto PrimExpandDims = GetPythonOps("expand_dims", "mindspore.ops.functional");
|
||||
ret_graph->set_output(NewCNode({NewValueNode(PrimExpandDims), tensor_node, NewValueNode(0)}, ret_graph));
|
||||
return ret_graph;
|
||||
}
|
||||
|
||||
FuncGraphPtr TupleGetItemTensor::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
|
||||
// select indexed item
|
||||
// args: tuple of items, index
|
||||
|
|
|
@ -206,8 +206,6 @@ class TensorSlice : public MetaFuncGraph {
|
|||
MS_DECLARE_PARENT(TensorSlice, MetaFuncGraph)
|
||||
FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
|
||||
friend bool operator==(const TensorSlice &lhs, const TensorSlice &rhs) { return lhs.name_ == rhs.name_; }
|
||||
|
||||
FuncGraphPtr ExpandADim(const FuncGraphPtr &ret_graph, const AnfNodePtr &tensor_node) const;
|
||||
};
|
||||
using TensorSlicePtr = std::shared_ptr<TensorSlice>;
|
||||
|
||||
|
|
|
@ -101,6 +101,7 @@ const char kNameReLU6[] = "ReLU6";
|
|||
const char kNameReLU6Grad[] = "ReLU6Grad";
|
||||
const char kNameElu[] = "Elu";
|
||||
const char kNameEluGrad[] = "EluGrad";
|
||||
const char kNameScatterUpdate[] = "ScatterUpdate";
|
||||
const char kNameScatterNdUpdate[] = "ScatterNdUpdate";
|
||||
const char kNameScatterMax[] = "ScatterMax";
|
||||
const char kNameNMSWithMask[] = "NMSWithMask";
|
||||
|
@ -256,6 +257,7 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma
|
|||
{string(kNameResizeBilinear), ADPT_DESC(ResizeBilinearV2D)},
|
||||
{string(kNameZerosLike), ADPT_DESC(ZerosLike)},
|
||||
{string(kNameOnesLike), ADPT_DESC(OnesLike)},
|
||||
{string(kNameScatterUpdate), ADPT_DESC(ScatterUpdate)},
|
||||
{string(kNameScatterNdUpdate), ADPT_DESC(ScatterNdUpdate)},
|
||||
{string(kNameScatterMax), ADPT_DESC(ScatterMax)},
|
||||
{string(kNameNMSWithMask), ADPT_DESC(NMSWithMask)},
|
||||
|
|
|
@ -515,6 +515,11 @@ INPUT_MAP(Unpack) = {{1, INPUT_DESC(x)}};
|
|||
ATTR_MAP(Unpack) = {{"axis", ATTR_DESC(axis, AnyTraits<int>())}, {"num", ATTR_DESC(num, AnyTraits<int>())}};
|
||||
DYN_OUTPUT_MAP(Unpack) = {{0, DYN_OUTPUT_DESC(y)}};
|
||||
|
||||
// ScatterUpdate
|
||||
INPUT_MAP(ScatterUpdate) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(indices)}, {3, INPUT_DESC(updates)}};
|
||||
ATTR_MAP(ScatterUpdate) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
|
||||
OUTPUT_MAP(ScatterUpdate) = {{0, OUTPUT_DESC(var)}};
|
||||
|
||||
// ScatterNdUpdate
|
||||
INPUT_MAP(ScatterNdUpdate) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(indices)}, {3, INPUT_DESC(updates)}};
|
||||
ATTR_MAP(ScatterNdUpdate) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
|
||||
|
|
|
@ -132,6 +132,8 @@ DECLARE_OP_ADAPTER(ZerosLike)
|
|||
DECLARE_OP_USE_OUTPUT(ZerosLike)
|
||||
DECLARE_OP_ADAPTER(OnesLike)
|
||||
DECLARE_OP_USE_OUTPUT(OnesLike)
|
||||
DECLARE_OP_ADAPTER(ScatterUpdate)
|
||||
DECLARE_OP_USE_OUTPUT(ScatterUpdate)
|
||||
DECLARE_OP_ADAPTER(ScatterNdUpdate)
|
||||
DECLARE_OP_USE_OUTPUT(ScatterNdUpdate)
|
||||
DECLARE_OP_ADAPTER(ScatterMax)
|
||||
|
|
|
@ -179,14 +179,15 @@ from .bounding_box_encode import _bounding_box_encode_tbe
|
|||
from .check_valid import _check_valid_tbe
|
||||
from .iou import _iou_tbe
|
||||
from .arg_max import _arg_max_tbe
|
||||
from .nms_with_mask import nms_with_mask_op_info
|
||||
from .random_choice_with_mask import random_choice_with_mask_op_info
|
||||
from .sgd import sgd_op_info
|
||||
from .lars_update import lars_update_op_info
|
||||
from .nms_with_mask import _nms_with_mask_tbe
|
||||
from .random_choice_with_mask import _random_choice_with_mask_tbe
|
||||
from .sgd import _sgd_tbe
|
||||
from .lars_update import _lars_update_tbe
|
||||
from .bn_training_update_v2 import _bn_training_update_v2_tbe
|
||||
from .square_sum_all import square_sum_all_op_info
|
||||
from .square_sum_all import _square_sum_all_tbe
|
||||
from .pack import _pack_tbe
|
||||
from .unpack import _unpack_tbe
|
||||
from .scatter_update import _scatter_update_tbe
|
||||
from .prelu import _prelu_tbe
|
||||
from .prelu_grad import _prelu_grad_tbe
|
||||
from .binary_cross_entropy import _binary_cross_entropy_tbe
|
||||
|
|
|
@ -0,0 +1,42 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""ScatterUpdate op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
scatter_update_op_info = TBERegOp("ScatterUpdate") \
|
||||
.fusion_type("ELEMWISE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("scatter_update.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("scatter_update") \
|
||||
.partial_flag(True) \
|
||||
.attr("use_locking", "optional", "bool", "all") \
|
||||
.input(0, "var", False, "required", "all") \
|
||||
.input(1, "indices", False, "required", "all") \
|
||||
.input(1, "updates", False, "required", "all") \
|
||||
.output(0, "var", False, "required", "all") \
|
||||
.dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I8_Default, DataType.I8_Default) \
|
||||
.dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default, DataType.U8_Default,) \
|
||||
.dtype_format(DataType.BOOL_Default, DataType.I32_Default, DataType.BOOL_Default, DataType.BOOL_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(scatter_update_op_info)
|
||||
def _scatter_update_tbe():
|
||||
"""ScatterUpdate TBE register"""
|
||||
return
|
|
@ -14,6 +14,6 @@
|
|||
# ============================================================================
|
||||
|
||||
"""ops utils."""
|
||||
from .utils import _get_broadcast_shape, _get_concat_offset
|
||||
from .utils import get_broadcast_shape, get_concat_offset
|
||||
|
||||
__all__ = ['_get_broadcast_shape', '_get_concat_offset']
|
||||
__all__ = ['get_broadcast_shape', 'get_concat_offset']
|
||||
|
|
|
@ -19,7 +19,8 @@ from ..._checkparam import Validator as validator
|
|||
from ..._checkparam import Rel
|
||||
from ...common import dtype as mstype
|
||||
|
||||
def _get_broadcast_shape(x_shape, y_shape, prim_name):
|
||||
|
||||
def get_broadcast_shape(x_shape, y_shape, prim_name):
|
||||
"""
|
||||
Doing broadcast between tensor x and tensor y.
|
||||
|
||||
|
@ -37,7 +38,7 @@ def _get_broadcast_shape(x_shape, y_shape, prim_name):
|
|||
Examples:
|
||||
>>> x_shape = [1, 2, 3]
|
||||
>>> y_shape = [1, 2]
|
||||
>>> broadcast_shape = _get_broadcast_shape(x_shape, y_shape)
|
||||
>>> broadcast_shape = get_broadcast_shape(x_shape, y_shape)
|
||||
"""
|
||||
if x_shape == y_shape:
|
||||
return x_shape
|
||||
|
@ -54,15 +55,14 @@ def _get_broadcast_shape(x_shape, y_shape, prim_name):
|
|||
elif x_shape[i] == y_shape[i]:
|
||||
broadcast_shape_back.append(x_shape[i])
|
||||
else:
|
||||
raise ValueError("For '{}' the x_shape {} and y_shape {} can not broadcast.".format(
|
||||
prim_name, x_shape, y_shape))
|
||||
raise ValueError(f"For '{prim_name}', the x_shape {x_shape} and y_shape {y_shape} can not broadcast.")
|
||||
|
||||
broadcast_shape_front = y_shape[0: y_len - length] if length == x_len else x_shape[0: x_len - length]
|
||||
broadcast_shape = broadcast_shape_front + broadcast_shape_back
|
||||
broadcast_shape = list(broadcast_shape_front) + broadcast_shape_back
|
||||
return broadcast_shape
|
||||
|
||||
|
||||
def _get_concat_offset(x_shp, x_type, axis, prim_name):
|
||||
def get_concat_offset(x_shp, x_type, axis, prim_name):
|
||||
"""for concat and concatoffset check args and compute offset"""
|
||||
validator.check_value_type("shape", x_shp, [tuple], prim_name)
|
||||
validator.check_integer("input_x rank", len(x_shp), 0, Rel.GT, prim_name)
|
||||
|
@ -73,7 +73,7 @@ def _get_concat_offset(x_shp, x_type, axis, prim_name):
|
|||
if axis < 0:
|
||||
axis = axis + rank_base
|
||||
all_shp = x_shp[0][axis]
|
||||
offset = [0,]
|
||||
offset = [0]
|
||||
for i in range(1, len(x_shp)):
|
||||
v = x_shp[i]
|
||||
validator.check('len of x_shp[%d]' % i, len(v), 'len of x_shp[0]', len(x_shp[0]), Rel.EQ, prim_name)
|
||||
|
|
|
@ -1,226 +0,0 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""constexpr util"""
|
||||
|
||||
from functools import reduce
|
||||
import numpy as np
|
||||
from ...primitive import constexpr
|
||||
from ....common.tensor import Tensor
|
||||
from ....common import dtype as mstype
|
||||
from ...._extends.utils import Slice, Ellipsis_
|
||||
|
||||
@constexpr
|
||||
def check_equal(param1, param2, msg="{},{}"):
|
||||
"""Checks whether the two parameters are equal or not."""
|
||||
if param1 != param2:
|
||||
raise ValueError(msg.format(param1, param2))
|
||||
return param1
|
||||
|
||||
|
||||
@constexpr
|
||||
def check_ellipsis_shape_size(data_shape, value_shape, data_size, value_size):
|
||||
"""Checks the shape and size of the sensor and value."""
|
||||
if data_shape == value_shape or data_size == value_size or value_size == 1:
|
||||
return True
|
||||
raise ValueError("The value(shape={}), can not assign to tensor(shape={}).".format(value_shape, data_shape))
|
||||
|
||||
|
||||
@constexpr
|
||||
def check_tensor_setitem_index(index, element_type=None):
|
||||
"""Checks tuple index type of tensor assignment."""
|
||||
if index is None:
|
||||
raise IndexError("Tensor's index cannot be None.")
|
||||
# eg. Tensor[Slice] = u
|
||||
if isinstance(index, Slice):
|
||||
return True
|
||||
# eg. Tensor[tuple] = u
|
||||
if isinstance(index, tuple):
|
||||
if not index:
|
||||
raise IndexError("Tensor's index cannot be empty.")
|
||||
# eg. Tensor[tuple(Slice...)] = u
|
||||
if isinstance(index[0], (Slice, Ellipsis_, int)):
|
||||
return True
|
||||
raise IndexError("Index of type '{}' is not supported yet.".format(type(index[0])))
|
||||
# eg. Tensor[Tensor[dtype=bool]] = u
|
||||
if index == mstype.tensor:
|
||||
if element_type is None or element_type != mstype.bool_:
|
||||
raise TypeError(
|
||||
"The index of tensor should be a bool type tensor. "
|
||||
"{} type is not supported yet.".format(element_type))
|
||||
return True
|
||||
|
||||
raise IndexError("Index of type '{}' is not supported yet.".format(type(index)))
|
||||
|
||||
|
||||
@constexpr
|
||||
def is_same_type(inst, type_):
|
||||
"""
|
||||
Checks whether an object is an instance of a target type.
|
||||
|
||||
Inputs:
|
||||
inst (mindspore.dtype): Inspected type.
|
||||
type_ (mindspore.dtype): Target type.
|
||||
|
||||
Outputs:
|
||||
bool, the check result.
|
||||
"""
|
||||
return inst == type_
|
||||
|
||||
|
||||
def slice_expand(input_slices, shape):
|
||||
"""
|
||||
Converts slice to indices.
|
||||
|
||||
Inputs:
|
||||
slices (Union[Slice, tuple[Slice]]): Slice tuple or slice.
|
||||
shape (tuple): The shape of a sensor is an integer element tuple.
|
||||
|
||||
Outputs:
|
||||
tuple[list], This is expressed as (begins, ends, strides).
|
||||
"""
|
||||
begin = []
|
||||
end = []
|
||||
strides = []
|
||||
index = 0
|
||||
slices = None
|
||||
# Slice or tuple(Slice...)
|
||||
if isinstance(input_slices, Slice):
|
||||
slices = (input_slices,)
|
||||
elif isinstance(input_slices, (tuple, list)) and input_slices and isinstance(input_slices[0], (Slice, Ellipsis_)):
|
||||
is_have_ellipsis = False
|
||||
for _, element in enumerate(input_slices):
|
||||
if isinstance(element, Ellipsis_):
|
||||
is_have_ellipsis = True
|
||||
break
|
||||
if is_have_ellipsis:
|
||||
slices = ellipsis2slice(input_slices, shape)
|
||||
else:
|
||||
slices = input_slices
|
||||
else:
|
||||
raise IndexError("Tensor's index type is not supported yet.")
|
||||
|
||||
for s in slices:
|
||||
start = 0 if (s.start is None) else s.start
|
||||
stop = shape[index] if (s.end is None) else s.end
|
||||
step = 1 if (s.step is None) else s.step
|
||||
begin.append(start)
|
||||
end.append(stop)
|
||||
strides.append(step)
|
||||
index += 1
|
||||
while index < len(shape):
|
||||
begin.append(0)
|
||||
end.append(shape[index])
|
||||
strides.append(1)
|
||||
index += 1
|
||||
return begin, end, strides
|
||||
|
||||
|
||||
def ellipsis2slice(input_, shape):
|
||||
"""Converts ellipsis to slice."""
|
||||
input_slice = input_
|
||||
result = []
|
||||
if isinstance(input_, Ellipsis_):
|
||||
input_slice = (input_,)
|
||||
ell_count = 0
|
||||
for _, element in enumerate(input_slice):
|
||||
if not isinstance(element, Ellipsis_):
|
||||
result.append(element)
|
||||
continue
|
||||
ell_count += 1
|
||||
if ell_count > 1:
|
||||
raise IndexError("There cannot be more than one ellisis (...) in the index of the tensor, "
|
||||
"but it is currently {}".format(input_slice))
|
||||
for _ in range(len(shape) - len(input_slice) + 1):
|
||||
result.append(Slice(None, None, None))
|
||||
return tuple(result)
|
||||
|
||||
|
||||
@constexpr
|
||||
def slice2indices(input_slices, shape):
|
||||
"""
|
||||
Converts slice to indices.
|
||||
|
||||
Inputs:
|
||||
slices (Union[Slice, tuple[Slice]]): Slice tuple or slice.
|
||||
shape (tuple): The shape of a tensor is an integer element tuple.
|
||||
|
||||
Outputs:
|
||||
Tensor, the shape is (n, 1).
|
||||
"""
|
||||
begin, end, strides = slice_expand(input_slices, shape)
|
||||
np_r = []
|
||||
for i, element in enumerate(shape):
|
||||
s = begin[i] if (begin[i] >= 0) else (element + begin[i])
|
||||
e = end[i] if (end[i] >= 0) else (element + end[i])
|
||||
np_r.append(np.r_[s:e:strides[i]])
|
||||
# Reference: np.ravel_multi_index((np.ix_(np.r_[1:3:1], np.r_[0:4:1], np.r_[4:0:-1])), a.shape)
|
||||
np_ix = np.ix_(*np_r)
|
||||
ravel = np.ravel_multi_index(np_ix, shape)
|
||||
ravel = Tensor(ravel.reshape(-1, 1), dtype=mstype.int32)
|
||||
return ravel
|
||||
|
||||
@constexpr
|
||||
def check_indices(indices_size, index):
|
||||
"""Checks indices whether is empty."""
|
||||
if indices_size < 1:
|
||||
raise IndexError("The tensor's index is unreasonable. index:{}".format(index))
|
||||
return indices_size
|
||||
|
||||
|
||||
@constexpr
|
||||
def check_indices_value_size(indices_size, value_size):
|
||||
"""Checks if the sizes are already matched."""
|
||||
if value_size < 1:
|
||||
raise ValueError("The value assigned to tensor cannot be empty.")
|
||||
if value_size > 1:
|
||||
if value_size != indices_size:
|
||||
raise ValueError(
|
||||
"The value given to tensor does not match the index size,"
|
||||
" value size:{}, indics size:{}".format(value_size, indices_size))
|
||||
return value_size
|
||||
|
||||
@constexpr
|
||||
def integer_to_indices(index, shape):
|
||||
"""Converts int or tuple[int] to indices."""
|
||||
size = reduce(lambda x, y: x * y, shape)
|
||||
range_ = np.arange(size).reshape(shape)
|
||||
value = range_[index]
|
||||
value = value.reshape(-1, 1)
|
||||
return Tensor(value, dtype=mstype.int32)
|
||||
|
||||
@constexpr
|
||||
def tuple_element_is_slice(indexs):
|
||||
"""Judges tuple element type."""
|
||||
if not indexs:
|
||||
raise IndexError("Tensor's index cannot be empty.")
|
||||
if isinstance(indexs, tuple):
|
||||
for _, ele in enumerate(indexs):
|
||||
if not isinstance(ele, Slice):
|
||||
return False
|
||||
return True
|
||||
return False
|
||||
|
||||
@constexpr
|
||||
def tuple_element_is_int(indexs):
|
||||
"""Judges tuple element type."""
|
||||
if not indexs:
|
||||
raise IndexError("Tensor's index cannot be empty.")
|
||||
if isinstance(indexs, tuple):
|
||||
for _, ele in enumerate(indexs):
|
||||
if not isinstance(ele, int):
|
||||
return False
|
||||
return True
|
||||
return False
|
|
@ -0,0 +1,487 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""constexpr util"""
|
||||
from functools import reduce
|
||||
import numpy as np
|
||||
from ...primitive import constexpr
|
||||
from ....common.tensor import Tensor
|
||||
from ....common import dtype as mstype
|
||||
from ...._extends.utils import Slice, Ellipsis_
|
||||
from ....ops import _utils as op_utils
|
||||
from ...composite import base
|
||||
from .... import log as logger
|
||||
from ... import functional as F
|
||||
from ... import operations as P
|
||||
|
||||
hyper_map = base.HyperMap()
|
||||
pack = P.Pack(axis=-1)
|
||||
|
||||
ALL_TENSOR = 0
|
||||
NO_TENSOR = 1
|
||||
CONTAIN_TENSOR = 2
|
||||
ALL_SCALAR = 3
|
||||
|
||||
INT_ = 0
|
||||
BOOL_ = 1
|
||||
UNSUPPORTED_DTYPE = 2
|
||||
|
||||
TENSOR_SETITEM = "tensor setitem"
|
||||
TENSOR_GETITEM = "tensor getitem"
|
||||
|
||||
SET_ITEM_BY_ONE_TENSOR = 0
|
||||
SET_ITEM_BY_TUPLE_OF_TENSOR = 1
|
||||
|
||||
|
||||
@constexpr
|
||||
def check_equal(param1, param2, msg="{},{}"):
|
||||
"""Checks whether the two parameters are equal or not."""
|
||||
if param1 != param2:
|
||||
raise ValueError(msg.format(param1, param2))
|
||||
return param1
|
||||
|
||||
|
||||
@constexpr
|
||||
def check_ellipsis_shape_size(data_shape, value_shape, data_size, value_size):
|
||||
"""Checks the shape and size of the sensor and value."""
|
||||
if data_shape == value_shape or data_size == value_size or value_size == 1:
|
||||
return True
|
||||
raise ValueError("The value(shape={}), can not assign to tensor(shape={}).".format(value_shape, data_shape))
|
||||
|
||||
|
||||
@constexpr
|
||||
def check_tensor_setitem_index(index, element_type=None):
|
||||
"""Checks tuple index type of tensor assignment."""
|
||||
if index is None:
|
||||
raise IndexError("Tensor's index cannot be None.")
|
||||
# eg. Tensor[Slice] = u
|
||||
if isinstance(index, Slice):
|
||||
return True
|
||||
# eg. Tensor[tuple] = u
|
||||
if isinstance(index, tuple):
|
||||
if not index:
|
||||
raise IndexError("Tensor's index cannot be empty.")
|
||||
# eg. Tensor[tuple(Slice...)] = u
|
||||
if isinstance(index[0], (Slice, Ellipsis_, int)):
|
||||
return True
|
||||
raise IndexError("Index of type '{}' is not supported yet.".format(type(index[0])))
|
||||
# eg. Tensor[Tensor[dtype=bool]] = u
|
||||
if isinstance(index, mstype.tensor_type):
|
||||
if element_type is None or element_type != mstype.bool_:
|
||||
raise TypeError(
|
||||
"The index of tensor should be a bool type tensor. "
|
||||
"{} type is not supported yet.".format(element_type))
|
||||
return True
|
||||
|
||||
raise IndexError("Index of type '{}' is not supported yet.".format(type(index)))
|
||||
|
||||
|
||||
@constexpr
|
||||
def is_same_type(inst, type_):
|
||||
"""
|
||||
Checks whether an object is an instance of a target type.
|
||||
|
||||
Inputs:
|
||||
inst (mindspore.dtype): Inspected type.
|
||||
type_ (mindspore.dtype): Target type.
|
||||
|
||||
Outputs:
|
||||
bool, the check result.
|
||||
"""
|
||||
return inst == type_
|
||||
|
||||
|
||||
def slice_expand(input_slices, shape):
|
||||
"""
|
||||
Converts slice to indices.
|
||||
|
||||
Inputs:
|
||||
slices (Union[Slice, tuple[Slice]]): Slice tuple or slice.
|
||||
shape (tuple): The shape of a sensor is an integer element tuple.
|
||||
|
||||
Outputs:
|
||||
tuple[list], This is expressed as (begins, ends, strides).
|
||||
"""
|
||||
begin = []
|
||||
end = []
|
||||
strides = []
|
||||
index = 0
|
||||
slices = None
|
||||
# Slice or tuple(Slice...)
|
||||
if isinstance(input_slices, Slice):
|
||||
slices = (input_slices,)
|
||||
elif isinstance(input_slices, (tuple, list)) and input_slices and isinstance(input_slices[0], (Slice, Ellipsis_)):
|
||||
is_have_ellipsis = False
|
||||
for _, element in enumerate(input_slices):
|
||||
if isinstance(element, Ellipsis_):
|
||||
is_have_ellipsis = True
|
||||
break
|
||||
if is_have_ellipsis:
|
||||
slices = ellipsis2slice(input_slices, shape)
|
||||
else:
|
||||
slices = input_slices
|
||||
else:
|
||||
raise IndexError("Tensor's index type is not supported yet.")
|
||||
|
||||
for s in slices:
|
||||
start = 0 if (s.start is None) else s.start
|
||||
stop = shape[index] if (s.end is None) else s.end
|
||||
step = 1 if (s.step is None) else s.step
|
||||
begin.append(start)
|
||||
end.append(stop)
|
||||
strides.append(step)
|
||||
index += 1
|
||||
while index < len(shape):
|
||||
begin.append(0)
|
||||
end.append(shape[index])
|
||||
strides.append(1)
|
||||
index += 1
|
||||
return begin, end, strides
|
||||
|
||||
|
||||
def ellipsis2slice(input_, shape):
|
||||
"""Converts ellipsis to slice."""
|
||||
input_slice = input_
|
||||
result = []
|
||||
if isinstance(input_, Ellipsis_):
|
||||
input_slice = (input_,)
|
||||
ell_count = 0
|
||||
for _, element in enumerate(input_slice):
|
||||
if not isinstance(element, Ellipsis_):
|
||||
result.append(element)
|
||||
continue
|
||||
ell_count += 1
|
||||
if ell_count > 1:
|
||||
raise IndexError("There cannot be more than one ellisis (...) in the index of the tensor, "
|
||||
"but it is currently {}".format(input_slice))
|
||||
for _ in range(len(shape) - len(input_slice) + 1):
|
||||
result.append(Slice(None, None, None))
|
||||
return tuple(result)
|
||||
|
||||
|
||||
@constexpr
|
||||
def slice2indices(input_slices, shape):
|
||||
"""
|
||||
Converts slice to indices.
|
||||
|
||||
Inputs:
|
||||
slices (Union[Slice, tuple[Slice]]): Slice tuple or slice.
|
||||
shape (tuple): The shape of a tensor is an integer element tuple.
|
||||
|
||||
Outputs:
|
||||
Tensor, the shape is (n, 1).
|
||||
"""
|
||||
begin, end, strides = slice_expand(input_slices, shape)
|
||||
np_r = []
|
||||
for i, element in enumerate(shape):
|
||||
s = begin[i] if (begin[i] >= 0) else (element + begin[i])
|
||||
e = end[i] if (end[i] >= 0) else (element + end[i])
|
||||
np_r.append(np.r_[s:e:strides[i]])
|
||||
# Reference: np.ravel_multi_index((np.ix_(np.r_[1:3:1], np.r_[0:4:1], np.r_[4:0:-1])), a.shape)
|
||||
np_ix = np.ix_(*np_r)
|
||||
ravel = np.ravel_multi_index(np_ix, shape)
|
||||
ravel = Tensor(ravel.reshape(-1, 1), dtype=mstype.int32)
|
||||
return ravel
|
||||
|
||||
|
||||
@constexpr
|
||||
def check_indices(indices_size, index):
|
||||
"""Checks indices whether is empty."""
|
||||
if indices_size < 1:
|
||||
raise IndexError("The tensor's index is unreasonable. index:{}".format(index))
|
||||
return indices_size
|
||||
|
||||
|
||||
@constexpr
|
||||
def check_indices_value_size(indices_size, value_size):
|
||||
"""Checks if the sizes are already matched."""
|
||||
if value_size < 1:
|
||||
raise ValueError("The value assigned to tensor cannot be empty.")
|
||||
if value_size > 1:
|
||||
if value_size != indices_size:
|
||||
raise ValueError(
|
||||
"The value given to tensor does not match the index size,"
|
||||
" value size:{}, indics size:{}".format(value_size, indices_size))
|
||||
return value_size
|
||||
|
||||
|
||||
@constexpr
|
||||
def integer_to_indices(index, shape):
|
||||
"""Converts int or tuple[int] to indices."""
|
||||
size = reduce(lambda x, y: x * y, shape)
|
||||
range_ = np.arange(size).reshape(shape)
|
||||
value = range_[index]
|
||||
value = value.reshape(-1, 1)
|
||||
return Tensor(value, dtype=mstype.int32)
|
||||
|
||||
|
||||
@constexpr
|
||||
def tuple_element_is_slice(indexs):
|
||||
"""Judges tuple element type."""
|
||||
if not indexs:
|
||||
raise IndexError("Tensor's index cannot be empty.")
|
||||
if isinstance(indexs, tuple):
|
||||
for _, ele in enumerate(indexs):
|
||||
if not isinstance(ele, Slice):
|
||||
return False
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
@constexpr
|
||||
def tuple_element_is_int(indexs):
|
||||
"""Judges tuple element type."""
|
||||
if not indexs:
|
||||
raise IndexError("Tensor's index cannot be empty.")
|
||||
if isinstance(indexs, tuple):
|
||||
for _, ele in enumerate(indexs):
|
||||
if not isinstance(ele, int):
|
||||
return False
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
@constexpr
|
||||
def tuple_elements_type(types):
|
||||
"""Judges the type of all elements of the tuple."""
|
||||
tensors_number = 0
|
||||
for ele in types:
|
||||
if isinstance(ele, mstype.tensor_type):
|
||||
tensors_number += 1
|
||||
if tensors_number == len(types):
|
||||
return ALL_TENSOR
|
||||
if tensors_number == 0:
|
||||
return NO_TENSOR
|
||||
return CONTAIN_TENSOR
|
||||
|
||||
|
||||
@constexpr
|
||||
def check_value_elements(data_dtype, types):
|
||||
"""Judges the type of all elements of the tuple."""
|
||||
tensors_number = 0
|
||||
scalars_number = 0
|
||||
for i, ele in enumerate(types):
|
||||
if isinstance(ele, mstype.tensor_type):
|
||||
ele_dtype = ele.element_type()
|
||||
if data_dtype == ele_dtype:
|
||||
tensors_number += 1
|
||||
else:
|
||||
raise TypeError(f"For '{TENSOR_SETITEM}', the data type of {i}th tensor '{ele_dtype}' "
|
||||
f"in value tuple is not consistent with origin tensor data type '{data_dtype}'.")
|
||||
elif mstype.issubclass_(ele, data_dtype):
|
||||
scalars_number += 1
|
||||
else:
|
||||
raise TypeError(f"For '{TENSOR_SETITEM}', the {i}th element type '{ele}' in "
|
||||
f"value tuple is not consistent with origin tensor data type '{data_dtype}'.")
|
||||
if tensors_number == len(types):
|
||||
return ALL_TENSOR
|
||||
if scalars_number == len(types):
|
||||
return ALL_SCALAR
|
||||
raise TypeError(f"For '{TENSOR_SETITEM}', the value does not support scalar and tensor mixing, but got {types}.")
|
||||
|
||||
|
||||
@constexpr
|
||||
def get_index_tensor_dtype(dtype):
|
||||
"""Check a tuple of tensor data type."""
|
||||
if dtype == mstype.int32:
|
||||
return INT_
|
||||
if dtype == mstype.bool_:
|
||||
return BOOL_
|
||||
raise TypeError(f"For '{TENSOR_SETITEM}', the index tensor data type '{dtype}' is not supported.")
|
||||
|
||||
|
||||
@constexpr
|
||||
def check_index_tensors_dtype(dtypes, op_name):
|
||||
"""Check a tuple of tensor data type."""
|
||||
if op_name == TENSOR_GETITEM:
|
||||
valid_dtypes = (mstype.int32, mstype.int64)
|
||||
elif op_name == TENSOR_SETITEM:
|
||||
valid_dtypes = (mstype.int32,)
|
||||
else:
|
||||
raise ValueError("Unsupported operation.")
|
||||
for ele in dtypes:
|
||||
if ele in valid_dtypes and ele == dtypes[0]:
|
||||
continue
|
||||
raise TypeError(f"For '{op_name}', the index tensors data type must be same, "
|
||||
f"and should be one of the following: {valid_dtypes}, but got {dtypes}.")
|
||||
return True
|
||||
|
||||
|
||||
@constexpr
|
||||
def check_tensor_dtype_valid(dtype, valid_dtypes):
|
||||
"""Check a tensor data type."""
|
||||
if dtype in valid_dtypes:
|
||||
return True
|
||||
raise TypeError(f"The index tensor data type must be one of "
|
||||
f"the following: {valid_dtypes}, but got {dtype}.")
|
||||
|
||||
|
||||
@constexpr
|
||||
def check_tensors_dtype_same(x_dtype, y_dtype, op_name):
|
||||
"""Check tensors data type same."""
|
||||
if x_dtype == y_dtype:
|
||||
return True
|
||||
raise TypeError(f"For '{op_name}', the value data type '{y_dtype}' "
|
||||
f"is not consistent with origin tensor data type {x_dtype}.")
|
||||
|
||||
|
||||
@constexpr
|
||||
def broadcast_shapes(shapes, op_name):
|
||||
"""Broadcasts a tuple of tensor."""
|
||||
broadcast_shape = shapes[0]
|
||||
for i, shape in enumerate(shapes):
|
||||
logger.debug(f"Broadcasts the {i}th tensor, the shape is {shape}.")
|
||||
broadcast_shape = op_utils.get_broadcast_shape(broadcast_shape, shape, op_name)
|
||||
return tuple(broadcast_shape)
|
||||
|
||||
|
||||
@constexpr
|
||||
def check_two_shapes_need_broadcast(shape_x, shape_y):
|
||||
"""Check two shapes need broadcast."""
|
||||
error = ValueError(f"For 'tensor setitem with tensor', the value tensor shape "
|
||||
f"{shape_y} could not broadcast the required updates shape {shape_x}.")
|
||||
if len(shape_y) > len(shape_x):
|
||||
raise error
|
||||
for i in range(-len(shape_y), 0):
|
||||
if shape_y[i] > shape_x[i]:
|
||||
raise error
|
||||
if shape_y[i] < shape_x[i] and shape_y[i] != 1:
|
||||
raise error
|
||||
if shape_y == shape_x:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
@constexpr
|
||||
def compute_multiples(origin_shape, broadcast_shape):
|
||||
"""Compute multiples between broadcast_shape with origin_shape."""
|
||||
len_gap = len(broadcast_shape) - len(origin_shape)
|
||||
return broadcast_shape[0:len_gap] + tuple(map(lambda x, y: x // y, broadcast_shape[len_gap:], origin_shape))
|
||||
|
||||
|
||||
def tile(broadcast_shape, x):
|
||||
multiples = compute_multiples(F.shape(x), broadcast_shape)
|
||||
return F.tile(x, multiples)
|
||||
|
||||
|
||||
@constexpr
|
||||
def check_shapes_same(value_shapes, op_name):
|
||||
"""Check if the shapes in the tuple are consistent."""
|
||||
for i, shape in enumerate(value_shapes):
|
||||
if shape != value_shapes[0]:
|
||||
raise ValueError(f"For '{op_name}', the {i}th tensor shape in value tuple "
|
||||
f"is not same as the first tensor shape.")
|
||||
return True
|
||||
|
||||
|
||||
@constexpr
|
||||
def convert_scalar_to_tensor(data_shape, data_dtype, indices_shape, value, op_type):
|
||||
"""Convert a scalar to a tensor."""
|
||||
if op_type == SET_ITEM_BY_ONE_TENSOR:
|
||||
updates_shape = indices_shape + data_shape[1:]
|
||||
else:
|
||||
updates_shape = indices_shape[:-1] + data_shape[indices_shape[-1]:]
|
||||
if isinstance(value, mstype.dtype_to_pytype(data_dtype)):
|
||||
return Tensor(np.full(updates_shape, value), dtype=data_dtype)
|
||||
raise TypeError(f"For '{TENSOR_SETITEM}', the value type '{value.__class__.__name__}'"
|
||||
f" is not consistent with tensor data type {data_dtype}.")
|
||||
|
||||
|
||||
@constexpr
|
||||
def convert_tuple_of_scalar_to_tensor(data_shape, data_dtype, index_shape, value, op_type):
|
||||
"""Convert a tuple of scalar to a tensor."""
|
||||
updates_shape = generate_updates_shape(data_shape, index_shape, op_type)
|
||||
if len(value) != updates_shape[-1]:
|
||||
raise ValueError(f"For '{TENSOR_SETITEM}', the number of elements : {len(value)} in the updates tuple "
|
||||
f"does not meet the requirements: {updates_shape[-1]}.")
|
||||
array = np.array(value, dtype=mstype.dtype_to_nptype(data_dtype))
|
||||
reps = compute_multiples(updates_shape[-1:], updates_shape)
|
||||
return Tensor(np.tile(array, reps))
|
||||
|
||||
|
||||
@constexpr
|
||||
def generate_updates_shape(data_shape, index_shape, op_type):
|
||||
"""Generate updates shape for 'tensor setitem'."""
|
||||
if op_type == SET_ITEM_BY_ONE_TENSOR:
|
||||
updates_shape = index_shape + data_shape[1:]
|
||||
else:
|
||||
updates_shape = index_shape[:-1] + data_shape[index_shape[-1]:]
|
||||
return updates_shape
|
||||
|
||||
|
||||
@constexpr
|
||||
def check_number_of_index_tensor(data_shape, tuple_len, op_name):
|
||||
"""Check if the number of index tensor exceeds the dimension of the operated tensor."""
|
||||
if tuple_len <= len(data_shape):
|
||||
return True
|
||||
raise IndexError(f"For '{op_name}', the number {tuple_len} of index tensor "
|
||||
f"is greater than the dimension {len(data_shape)} of the operated tensor.")
|
||||
|
||||
|
||||
def generate_indeices_from_tuple_of_tensor(data, tuple_index, op_name):
|
||||
"""Generate an indices tensor from a tuple of tensor."""
|
||||
indices = None
|
||||
check_index_tensor_number = check_number_of_index_tensor(F.shape(data), len(tuple_index), op_name)
|
||||
if check_index_tensor_number:
|
||||
dtype_tuple = hyper_map(F.dtype, tuple_index)
|
||||
check_dtypes = check_index_tensors_dtype(dtype_tuple, op_name)
|
||||
if check_dtypes:
|
||||
shape_tuple = hyper_map(F.shape, tuple_index)
|
||||
broadcast_shape = broadcast_shapes(shape_tuple, op_name)
|
||||
broadcast_tensors = hyper_map(F.partial(tile, broadcast_shape), tuple_index)
|
||||
indices = pack(broadcast_tensors)
|
||||
return indices
|
||||
|
||||
|
||||
def generate_updates_from_scalar(data, indices, value, op_type):
|
||||
"""Generate an updates tensor from a scalar."""
|
||||
data_shape = F.shape(data)
|
||||
indices_shape = F.shape(indices)
|
||||
data_dtype = F.dtype(data)
|
||||
return convert_scalar_to_tensor(data_shape, data_dtype, indices_shape, value, op_type)
|
||||
|
||||
|
||||
def generate_updates_from_tuple(data, index, value, op_type):
|
||||
"""Generate an updates tensor from a tuple."""
|
||||
value_types = hyper_map(F.typeof, value)
|
||||
data_dtype = F.dtype(data)
|
||||
value_elements_type = check_value_elements(data_dtype, value_types)
|
||||
if value_elements_type == ALL_TENSOR:
|
||||
value_shapes = hyper_map(F.shape, value)
|
||||
shapes_same = check_shapes_same(value_shapes, TENSOR_SETITEM)
|
||||
if shapes_same:
|
||||
value = F.pack(value)
|
||||
return generate_updates_from_tensor(data, index, value, op_type)
|
||||
|
||||
data_shape = F.shape(data)
|
||||
index_shape = F.shape(index)
|
||||
return convert_tuple_of_scalar_to_tensor(data_shape, data_dtype, index_shape, value, op_type)
|
||||
|
||||
|
||||
def generate_updates_from_tensor(data, index, value, op_type):
|
||||
"""Generate an updates tensor from a tensor."""
|
||||
data_shape = F.shape(data)
|
||||
index_shape = F.shape(index)
|
||||
value_shape = F.shape(value)
|
||||
data_dtype = F.dtype(data)
|
||||
value_dtype = F.dtype(value)
|
||||
updates_shape = value_shape
|
||||
check_dtype_same = check_tensors_dtype_same(data_dtype, value_dtype, TENSOR_SETITEM)
|
||||
if check_dtype_same:
|
||||
updates_shape = generate_updates_shape(data_shape, index_shape, op_type)
|
||||
need_broadcast = check_two_shapes_need_broadcast(updates_shape, value_shape)
|
||||
if need_broadcast:
|
||||
return tile(updates_shape, value)
|
||||
return value
|
|
@ -15,9 +15,10 @@
|
|||
|
||||
"""Implementation for getitem."""
|
||||
|
||||
from ...composite import base
|
||||
from . import _utils as multi_utils
|
||||
from ..import base
|
||||
from ... import functional as F
|
||||
|
||||
from ....common import dtype as mstype
|
||||
|
||||
getitem = base.MultitypeFuncGraph('getitem')
|
||||
"""
|
||||
|
@ -214,19 +215,45 @@ def _tensor_getitem_by_slice(data, slice_index):
|
|||
return _tensor_slice(data, slice_index)
|
||||
|
||||
|
||||
@getitem.register("Tensor", "Tensor")
|
||||
def _tensor_getitem_by_tensor(data, tensor_index):
|
||||
"""
|
||||
Getting item of tensor by slice.
|
||||
|
||||
Inputs:
|
||||
data (Tensor): A tensor.
|
||||
tensor_index (Tensor): An index expressed by tensor.
|
||||
|
||||
Outputs:
|
||||
Tensor, element type is same as the element type of data.
|
||||
"""
|
||||
check_dtypes = multi_utils.check_tensor_dtype_valid(F.dtype(tensor_index), (mstype.int32, mstype.int64))
|
||||
result = None
|
||||
if check_dtypes:
|
||||
result = F.gather(data, tensor_index, 0)
|
||||
return result
|
||||
|
||||
|
||||
@getitem.register("Tensor", "Tuple")
|
||||
def _tensor_getitem_by_slice_tuple(data, slice_tuple_index):
|
||||
def _tensor_getitem_by_tuple(data, tuple_index):
|
||||
"""
|
||||
Getting item of tensor by slice tuple.
|
||||
|
||||
Inputs:
|
||||
data (Tensor): A tensor.
|
||||
slice_tuple_index (tuple): Index in tuple.
|
||||
tuple_index (tuple): Index in tuple.
|
||||
|
||||
Outputs:
|
||||
Tensor, element type is same as the element type of data.
|
||||
"""
|
||||
return _tensor_slice(data, slice_tuple_index)
|
||||
index_types = multi_utils.hyper_map(F.typeof, tuple_index)
|
||||
index_elements_type = multi_utils.tuple_elements_type(index_types)
|
||||
result = None
|
||||
if index_elements_type == multi_utils.NO_TENSOR:
|
||||
result = _tensor_slice(data, tuple_index)
|
||||
if index_elements_type == multi_utils.ALL_TENSOR:
|
||||
result = _tensor_getitem_by_tuple_of_tensor(data, tuple_index)
|
||||
return result
|
||||
|
||||
|
||||
@getitem.register("Tensor", "Ellipsis")
|
||||
|
@ -242,3 +269,10 @@ def _tensor_getitem_by_ellipsis(data, ellipsis_index):
|
|||
Tensor, same as data.
|
||||
"""
|
||||
return _tensor_slice(data, ellipsis_index)
|
||||
|
||||
|
||||
def _tensor_getitem_by_tuple_of_tensor(data, tuple_index):
|
||||
"""Tensor getitem by a tuple of tensor."""
|
||||
indices = multi_utils.generate_indeices_from_tuple_of_tensor(data, tuple_index, multi_utils.TENSOR_GETITEM)
|
||||
result = F.gather_nd(data, indices)
|
||||
return result
|
||||
|
|
|
@ -18,10 +18,11 @@
|
|||
from ...composite import base
|
||||
from ....common import dtype as mstype
|
||||
from ... import functional as F
|
||||
from . import _multitype_ops_util as mult_util
|
||||
from . import _utils as multi_utils
|
||||
|
||||
setitem = base.MultitypeFuncGraph('setitem')
|
||||
|
||||
|
||||
@setitem.register("List", "Number", "String")
|
||||
def _list_setitem_with_string(data, number_index, value):
|
||||
"""
|
||||
|
@ -118,7 +119,7 @@ def _dict_setitem_with_number(data, key, value):
|
|||
|
||||
|
||||
@setitem.register("Tensor", "Tensor", "Tensor")
|
||||
def _tensor_setitem_by_tensor_v1(data, index, value_tensor):
|
||||
def _tensor_setitem_by_tensor_with_tensor(data, index, value_tensor):
|
||||
"""
|
||||
Tensor assignment.
|
||||
|
||||
|
@ -137,27 +138,15 @@ def _tensor_setitem_by_tensor_v1(data, index, value_tensor):
|
|||
Outputs:
|
||||
Tensor, element type and shape is same as data.
|
||||
"""
|
||||
result = None
|
||||
index_dtype = F.dtype(index)
|
||||
index_shape = F.shape(index)
|
||||
check_result = mult_util.check_tensor_setitem_index(mstype.tensor, index_dtype)
|
||||
if check_result:
|
||||
data_shape = F.shape(data)
|
||||
data_shape = mult_util.check_equal(data_shape, index_shape,
|
||||
"The tensor(shape={}) and tensor index(shape={}) should be the same shape.")
|
||||
size = F.size(value_tensor)
|
||||
size = mult_util.check_equal(1, size,
|
||||
"When assign value is a tensor, its size should be {}, but current size is {}.")
|
||||
dtype = F.dtype(data)
|
||||
u_cast = F.cast(value_tensor, dtype)
|
||||
one_data = F.ones_like(data)
|
||||
u = F.tensor_mul(one_data, u_cast)
|
||||
result = F.select(index, u, data)
|
||||
return result
|
||||
tensor_dtype = multi_utils.get_index_tensor_dtype(index_dtype)
|
||||
if tensor_dtype == multi_utils.INT_:
|
||||
return _tensor_setitem_by_int_tensor_with_tensor(data, index, value_tensor)
|
||||
return _tensor_setitem_by_bool_tensor_with_tensor(data, index, value_tensor)
|
||||
|
||||
|
||||
@setitem.register("Tensor", "Tensor", "Number")
|
||||
def _tensor_setitem_by_tensor_v2(data, index, value):
|
||||
def _tensor_setitem_by_tensor_with_number(data, index, value):
|
||||
"""
|
||||
Tensor assignment.
|
||||
|
||||
|
@ -171,22 +160,128 @@ def _tensor_setitem_by_tensor_v2(data, index, value):
|
|||
Inputs:
|
||||
data (Tensor): Assigned tensor.
|
||||
index (Tensor): Tensor of bool type.
|
||||
value_tensor (Number): Assignment value.
|
||||
value (Number): Assignment value.
|
||||
|
||||
Outputs:
|
||||
Tensor, element type and shape is same as data.
|
||||
"""
|
||||
result = None
|
||||
index_dtype = F.dtype(index)
|
||||
index_shape = F.shape(index)
|
||||
check_result = mult_util.check_tensor_setitem_index(mstype.tensor, index_dtype)
|
||||
if check_result:
|
||||
shape = F.shape(data)
|
||||
shape = mult_util.check_equal(
|
||||
shape, index_shape, "The tensor(shape={}) and tensor index(shape={}) should be the same shape.")
|
||||
dtype = F.dtype(data)
|
||||
u = F.fill(dtype, shape, value)
|
||||
result = F.select(index, u, data)
|
||||
tensor_dtype = multi_utils.get_index_tensor_dtype(index_dtype)
|
||||
if tensor_dtype == multi_utils.BOOL_:
|
||||
return _tensor_setitem_by_bool_tensor_with_scalar(data, index, value)
|
||||
return _tensor_setitem_by_int_tensor_with_scalar(data, index, value)
|
||||
|
||||
|
||||
@setitem.register("Tensor", "Tuple", "Number")
|
||||
def _tensor_setitem_by_tuple_with_number(data, tuple_index, value):
|
||||
"""
|
||||
Tensor assignment.
|
||||
|
||||
Note:
|
||||
Syntax support: A[B, C, D] = u.
|
||||
Restraint condition: 1) A is a Tensor, and B, C, D are index.
|
||||
2) u is a scalar.
|
||||
|
||||
Inputs:
|
||||
data (Tensor): Assigned tensor.
|
||||
index (Tuple): An index tuple.
|
||||
value (Number): Assignment value.
|
||||
|
||||
Outputs:
|
||||
Tensor, element type and shape is same as data.
|
||||
"""
|
||||
index_types = multi_utils.hyper_map(F.typeof, tuple_index)
|
||||
index_elements_type = multi_utils.tuple_elements_type(index_types)
|
||||
result = None
|
||||
if index_elements_type == multi_utils.NO_TENSOR:
|
||||
result = _tensor_assgin_number(data, tuple_index, value)
|
||||
if index_elements_type == multi_utils.ALL_TENSOR:
|
||||
indices = multi_utils.generate_indeices_from_tuple_of_tensor(data, tuple_index, multi_utils.TENSOR_SETITEM)
|
||||
updates = multi_utils.generate_updates_from_scalar(data, indices, value,
|
||||
multi_utils.SET_ITEM_BY_TUPLE_OF_TENSOR)
|
||||
result = F.scatter_nd_update(data, indices, updates)
|
||||
return result
|
||||
|
||||
|
||||
@setitem.register("Tensor", "Tuple", "Tensor")
|
||||
def _tensor_setitem_by_tuple_with_tensor(data, tuple_index, value):
|
||||
"""
|
||||
Tensor assignment.
|
||||
|
||||
Note:
|
||||
Syntax support: A[B, C, D] = U.
|
||||
Restraint condition: 1) A is a Tensor, and B, C, D are index Tensors.
|
||||
2) U is a Tensor.
|
||||
|
||||
Inputs:
|
||||
data (Tensor): Assigned tensor.
|
||||
index (Tuple): An index tuple.
|
||||
value (Tensor): Assignment tensor, should has the same data type as 'data'.
|
||||
|
||||
Outputs:
|
||||
Tensor, element type and shape is same as data.
|
||||
"""
|
||||
index_types = multi_utils.hyper_map(F.typeof, tuple_index)
|
||||
index_elements_type = multi_utils.tuple_elements_type(index_types)
|
||||
result = None
|
||||
if index_elements_type == multi_utils.NO_TENSOR:
|
||||
result = _tensor_assgin_tensor(data, tuple_index, value)
|
||||
if index_elements_type == multi_utils.ALL_TENSOR:
|
||||
indices = multi_utils.generate_indeices_from_tuple_of_tensor(data, tuple_index, multi_utils.TENSOR_SETITEM)
|
||||
updates = multi_utils.generate_updates_from_tensor(data, indices, value,
|
||||
multi_utils.SET_ITEM_BY_TUPLE_OF_TENSOR)
|
||||
result = F.scatter_nd_update(data, indices, updates)
|
||||
return result
|
||||
|
||||
|
||||
@setitem.register("Tensor", "Tuple", "Tuple")
|
||||
def _tensor_setitem_by_tuple_with_tuple(data, tuple_index, value):
|
||||
"""
|
||||
Tensor assignment.
|
||||
|
||||
Note:
|
||||
Syntax support: A[B, C, D] = U.
|
||||
Restraint condition: 1) A is a Tensor, and B, C, D are index Tensors.
|
||||
2) A B and C could be broadcast.
|
||||
3) U is a Tensor.
|
||||
|
||||
Inputs:
|
||||
data (Tensor): Assigned tensor.
|
||||
index (Tuple): A tuple of tensor, these tensor could be broadcast.
|
||||
value (Tensor): Assignment tensor, should has the same data type as 'data'.
|
||||
|
||||
Outputs:
|
||||
Tensor, element type and shape is same as data.
|
||||
"""
|
||||
index_types = multi_utils.hyper_map(F.typeof, tuple_index)
|
||||
index_elements_type = multi_utils.tuple_elements_type(index_types)
|
||||
result = None
|
||||
if index_elements_type == multi_utils.ALL_TENSOR:
|
||||
indices = multi_utils.generate_indeices_from_tuple_of_tensor(data, tuple_index, multi_utils.TENSOR_SETITEM)
|
||||
updates = multi_utils.generate_updates_from_tuple(data, indices, value,
|
||||
multi_utils.SET_ITEM_BY_TUPLE_OF_TENSOR)
|
||||
result = F.scatter_nd_update(data, indices, updates)
|
||||
return result
|
||||
|
||||
|
||||
@setitem.register("Tensor", "Tensor", "Tuple")
|
||||
def _tensor_setitem_by_tensor_v2(data, index, value):
|
||||
"""
|
||||
Tensor assignment.
|
||||
|
||||
Inputs:
|
||||
data (Tensor): Assigned tensor.
|
||||
index (Tensor): Tensor of bool type.
|
||||
value (Tuple): Assignment value.
|
||||
|
||||
Outputs:
|
||||
Tensor, element type and shape is same as data.
|
||||
"""
|
||||
index_dtype = F.dtype(index)
|
||||
check_dtype = multi_utils.check_tensor_dtype_valid(index_dtype, (mstype.int32, mstype.int64))
|
||||
result = None
|
||||
if check_dtype:
|
||||
result = _tensor_setitem_by_tensor_with_tuple(data, index, value)
|
||||
return result
|
||||
|
||||
|
||||
|
@ -212,66 +307,6 @@ def _tensor_setitem_with_slice_v3(data, input_slice, value):
|
|||
return _tensor_assgin_tensor(data, input_slice, value)
|
||||
|
||||
|
||||
@setitem.register("Tensor", "Tuple", "Tensor")
|
||||
def _tensor_setitem_with_slice_v4(data, input_slice, value):
|
||||
"""
|
||||
Tensor assignment.
|
||||
|
||||
Note:
|
||||
Syntax support: A[tuple(Slice)] = U, and A[tuple(Number)] = U
|
||||
Restraint condition: A is a Tensor
|
||||
Slice like "1:3, ::, :4:-1"
|
||||
U is a Tensor(size=1) or Tensor(size>1)
|
||||
|
||||
Inputs:
|
||||
data (Tensor): Assigned tensor.
|
||||
input_slice (Union[tuple[Slice], tuple[Number]]): Slice expression.
|
||||
value (Number): Assignment value.
|
||||
|
||||
Outputs:
|
||||
Tensor, element type and shape is same as data.
|
||||
"""
|
||||
return _tensor_assgin_tensor(data, input_slice, value)
|
||||
|
||||
|
||||
def _tensor_assgin_tensor(data, input_slice, value):
|
||||
"""Assigns a tensor value to the tensor by slice."""
|
||||
result = None
|
||||
check_result = mult_util.check_tensor_setitem_index(input_slice)
|
||||
if check_result:
|
||||
data_shape = F.shape(data)
|
||||
indices = mult_util.slice2indices(input_slice, data_shape)
|
||||
is_tuple_int = mult_util.tuple_element_is_int(input_slice)
|
||||
if is_tuple_int:
|
||||
indices = mult_util.integer_to_indices(input_slice, data_shape)
|
||||
result = _tensor_indices_tensor(data, data_shape, input_slice, indices, value)
|
||||
return result
|
||||
|
||||
|
||||
def _tensor_indices_tensor(data, data_shape, index, indices, value):
|
||||
"""Assigns a tensor value to the tensor."""
|
||||
data_size = F.size(data)
|
||||
data_dtype = F.dtype(data)
|
||||
indices_size = F.size(indices)
|
||||
indices_size = mult_util.check_indices(indices_size, index)
|
||||
update = F.fill(mstype.int32, (indices_size,), 1)
|
||||
condition_1d = F.scatter_nd(indices, update, (data_size,))
|
||||
condition = F.reshape(condition_1d, data_shape)
|
||||
condition = F.cast(condition, mstype.bool_)
|
||||
value_fill = None
|
||||
value_size = F.size(value)
|
||||
|
||||
value_size = mult_util.check_indices_value_size(indices_size, value_size)
|
||||
if value_size == 1:
|
||||
value_fill = F.fill(data_dtype, (indices_size,), 1)
|
||||
value = F.cast(value, data_dtype)
|
||||
value_fill = F.tensor_mul(value_fill, value)
|
||||
elif value_size > 1:
|
||||
value_fill = F.reshape(value, (indices_size,))
|
||||
value_1d = F.scatter_nd(indices, value_fill, (data_size,))
|
||||
u = F.reshape(value_1d, data_shape)
|
||||
return F.select(condition, u, data)
|
||||
|
||||
@setitem.register("Tensor", "Slice", "Number")
|
||||
def _tensor_setitem_with_slice_v1(data, input_slice, value):
|
||||
"""
|
||||
|
@ -294,63 +329,25 @@ def _tensor_setitem_with_slice_v1(data, input_slice, value):
|
|||
return _tensor_assgin_number(data, input_slice, value)
|
||||
|
||||
|
||||
@setitem.register("Tensor", "Tuple", "Number")
|
||||
def _tensor_setitem_with_slice_v2(data, input_slice, value):
|
||||
"""
|
||||
Tensor assignment.
|
||||
|
||||
Note:
|
||||
Syntax support: A[tuple(Slice)] = u, and A[tuple(Number)] = u
|
||||
Restraint condition: A is a Tensor.
|
||||
Slice like "1:3, ::, :4:-1"
|
||||
u is a scalar
|
||||
|
||||
Inputs:
|
||||
data (Tensor): Assigned tensor.
|
||||
input_slice (Union[tuple[Slice], tuple[Number]]): slice expression.
|
||||
value (Number): Assignment value.
|
||||
|
||||
Outputs:
|
||||
Tensor, element type and shape is same as data.
|
||||
"""
|
||||
return _tensor_assgin_number(data, input_slice, value)
|
||||
|
||||
|
||||
def _tensor_assgin_number(data, input_slice, value):
|
||||
"""Givens a scalar assign to tensor by slice"""
|
||||
check_result = mult_util.check_tensor_setitem_index(input_slice)
|
||||
check_result = multi_utils.check_tensor_setitem_index(input_slice)
|
||||
result = None
|
||||
if check_result:
|
||||
data_shape = F.shape(data)
|
||||
indices = mult_util.slice2indices(input_slice, data_shape)
|
||||
is_tuple_int = mult_util.tuple_element_is_int(input_slice)
|
||||
indices = multi_utils.slice2indices(input_slice, data_shape)
|
||||
is_tuple_int = multi_utils.tuple_element_is_int(input_slice)
|
||||
if is_tuple_int:
|
||||
indices = mult_util.integer_to_indices(input_slice, data_shape)
|
||||
indices = multi_utils.integer_to_indices(input_slice, data_shape)
|
||||
result = _tensor_indices_number(data, data_shape, input_slice, indices, value)
|
||||
return result
|
||||
|
||||
|
||||
def _tensor_indices_number(data, data_shape, index, indices, value):
|
||||
"""Assigns a scalar value to the tensor."""
|
||||
data_size = F.size(data)
|
||||
data_dtype = F.dtype(data)
|
||||
indices_size = F.size(indices)
|
||||
indices_size = mult_util.check_indices(indices_size, index)
|
||||
update = F.fill(mstype.int32, (indices_size,), 1)
|
||||
condition_1d = F.scatter_nd(indices, update, (data_size,))
|
||||
condition = F.reshape(condition_1d, data_shape)
|
||||
condition = F.cast(condition, mstype.bool_)
|
||||
value_fill = F.fill(data_dtype, (indices_size,), value)
|
||||
value_1d = F.scatter_nd(indices, value_fill, (data_size,))
|
||||
u = F.reshape(value_1d, data_shape)
|
||||
return F.select(condition, u, data)
|
||||
|
||||
|
||||
@setitem.register("Tensor", "Number", "Number")
|
||||
def _tensor_setitem_with_int_v1(data, index, value):
|
||||
"""Syntax: A[1] = 3"""
|
||||
data_shape = F.shape(data)
|
||||
indices = mult_util.integer_to_indices(index, data_shape)
|
||||
indices = multi_utils.integer_to_indices(index, data_shape)
|
||||
return _tensor_indices_number(data, data_shape, index, indices, value)
|
||||
|
||||
|
||||
|
@ -358,7 +355,7 @@ def _tensor_setitem_with_int_v1(data, index, value):
|
|||
def _tensor_setitem_with_int_v2(data, index, value):
|
||||
"""Syntax: A[1] = Tensor"""
|
||||
data_shape = F.shape(data)
|
||||
indices = mult_util.integer_to_indices(index, data_shape)
|
||||
indices = multi_utils.integer_to_indices(index, data_shape)
|
||||
return _tensor_indices_tensor(data, data_shape, index, indices, value)
|
||||
|
||||
|
||||
|
@ -379,7 +376,7 @@ def _tensor_setitem_with_ellipsis_v2(data, index, value):
|
|||
data_size = F.size(data)
|
||||
value_shape = F.shape(value)
|
||||
value_size = F.size(value)
|
||||
check_result = mult_util.check_ellipsis_shape_size(data_shape, value_shape, data_size, value_size)
|
||||
check_result = multi_utils.check_ellipsis_shape_size(data_shape, value_shape, data_size, value_size)
|
||||
if check_result:
|
||||
if data_size == value_size:
|
||||
result = F.reshape(value, data_shape)
|
||||
|
@ -389,3 +386,108 @@ def _tensor_setitem_with_ellipsis_v2(data, index, value):
|
|||
param2 = F.cast(value, data_dtype)
|
||||
result = F.tensor_mul(param1, param2)
|
||||
return result
|
||||
|
||||
|
||||
def _tensor_assgin_tensor(data, input_slice, value):
|
||||
"""Assigns a tensor value to the tensor by slice."""
|
||||
result = None
|
||||
check_result = multi_utils.check_tensor_setitem_index(input_slice)
|
||||
if check_result:
|
||||
data_shape = F.shape(data)
|
||||
indices = multi_utils.slice2indices(input_slice, data_shape)
|
||||
is_tuple_int = multi_utils.tuple_element_is_int(input_slice)
|
||||
if is_tuple_int:
|
||||
indices = multi_utils.integer_to_indices(input_slice, data_shape)
|
||||
result = _tensor_indices_tensor(data, data_shape, input_slice, indices, value)
|
||||
return result
|
||||
|
||||
|
||||
def _tensor_indices_tensor(data, data_shape, index, indices, value):
|
||||
"""Assigns a tensor value to the tensor."""
|
||||
data_size = F.size(data)
|
||||
data_dtype = F.dtype(data)
|
||||
indices_size = F.size(indices)
|
||||
indices_size = multi_utils.check_indices(indices_size, index)
|
||||
update = F.fill(mstype.int32, (indices_size,), 1)
|
||||
condition_1d = F.scatter_nd(indices, update, (data_size,))
|
||||
condition = F.reshape(condition_1d, data_shape)
|
||||
condition = F.cast(condition, mstype.bool_)
|
||||
value_fill = None
|
||||
value_size = F.size(value)
|
||||
|
||||
value_size = multi_utils.check_indices_value_size(indices_size, value_size)
|
||||
if value_size == 1:
|
||||
value_fill = F.fill(data_dtype, (indices_size,), 1)
|
||||
value = F.cast(value, data_dtype)
|
||||
value_fill = F.tensor_mul(value_fill, value)
|
||||
elif value_size > 1:
|
||||
value_fill = F.reshape(value, (indices_size,))
|
||||
value_1d = F.scatter_nd(indices, value_fill, (data_size,))
|
||||
u = F.reshape(value_1d, data_shape)
|
||||
return F.select(condition, u, data)
|
||||
|
||||
|
||||
def _tensor_indices_number(data, data_shape, index, indices, value):
|
||||
"""Assigns a scalar value to the tensor."""
|
||||
data_size = F.size(data)
|
||||
data_dtype = F.dtype(data)
|
||||
indices_size = F.size(indices)
|
||||
indices_size = multi_utils.check_indices(indices_size, index)
|
||||
update = F.fill(mstype.int32, (indices_size,), 1)
|
||||
condition_1d = F.scatter_nd(indices, update, (data_size,))
|
||||
condition = F.reshape(condition_1d, data_shape)
|
||||
condition = F.cast(condition, mstype.bool_)
|
||||
value_fill = F.fill(data_dtype, (indices_size,), value)
|
||||
value_1d = F.scatter_nd(indices, value_fill, (data_size,))
|
||||
u = F.reshape(value_1d, data_shape)
|
||||
return F.select(condition, u, data)
|
||||
|
||||
|
||||
def _tensor_setitem_by_tensor_with_tuple(data, index, value):
|
||||
"""Set a tensor item by a tensor with a tuple."""
|
||||
updates = multi_utils.generate_updates_from_tuple(data, index, value,
|
||||
multi_utils.SET_ITEM_BY_ONE_TENSOR)
|
||||
result = F.scatter_update(data, index, updates)
|
||||
return result
|
||||
|
||||
|
||||
def _tensor_setitem_by_int_tensor_with_scalar(data, index, value):
|
||||
"""Set a tensor item by a int tensor with a scalar."""
|
||||
updates = multi_utils.generate_updates_from_scalar(data, index, value,
|
||||
multi_utils.SET_ITEM_BY_ONE_TENSOR)
|
||||
return F.scatter_update(data, index, updates)
|
||||
|
||||
|
||||
def _tensor_setitem_by_bool_tensor_with_scalar(data, index, value):
|
||||
"""Set a tensor item by a bool tensor with a scalar."""
|
||||
index_shape = F.shape(index)
|
||||
shape = F.shape(data)
|
||||
shape = multi_utils.check_equal(
|
||||
shape, index_shape, "The tensor(shape={}) and tensor index(shape={}) should be the same shape.")
|
||||
dtype = F.dtype(data)
|
||||
u = F.fill(dtype, shape, value)
|
||||
return F.select(index, u, data)
|
||||
|
||||
|
||||
def _tensor_setitem_by_int_tensor_with_tensor(data, index, value):
|
||||
"""Set a tensor item by a int tensor with a tensor."""
|
||||
updates = multi_utils.generate_updates_from_tensor(data, index, value,
|
||||
multi_utils.SET_ITEM_BY_ONE_TENSOR)
|
||||
return F.scatter_update(data, index, updates)
|
||||
|
||||
|
||||
def _tensor_setitem_by_bool_tensor_with_tensor(data, index, value):
|
||||
"""Set a tensor item by a bool tensor with a tensor."""
|
||||
index_shape = F.shape(index)
|
||||
data_shape = F.shape(data)
|
||||
data_shape = multi_utils.check_equal(data_shape, index_shape,
|
||||
"The tensor(shape={}) and tensor index(shape={}) should be the same shape.")
|
||||
size = F.size(value)
|
||||
size = multi_utils.check_equal(1, size,
|
||||
"When assign value is a tensor, its size should be {}, but current size is {}.")
|
||||
dtype = F.dtype(data)
|
||||
u_cast = F.cast(value, dtype)
|
||||
one_data = F.ones_like(data)
|
||||
u = F.tensor_mul(one_data, u_cast)
|
||||
result = F.select(index, u, data)
|
||||
return result
|
||||
|
|
|
@ -31,6 +31,7 @@ dtype = P.DType()
|
|||
issubclass_ = P.IsSubClass()
|
||||
isinstance_ = P.IsInstance()
|
||||
fill = P.Fill()
|
||||
tile = P.Tile()
|
||||
select = P.Select()
|
||||
size = P.Size()
|
||||
ones_like = P.OnesLike()
|
||||
|
@ -70,6 +71,12 @@ scalar_cast = P.ScalarCast()
|
|||
print_ = P.Print()
|
||||
expand_dims = P.ExpandDims()
|
||||
scatter_nd = P.ScatterNd()
|
||||
gather = P.GatherV2()
|
||||
gather_nd = P.GatherNd()
|
||||
scatter_update = P.ScatterUpdate()
|
||||
scatter_nd_update = P.ScatterNdUpdate()
|
||||
pack = P.Pack()
|
||||
|
||||
|
||||
tuple_setitem = Primitive('tuple_setitem')
|
||||
tuple_getitem = Primitive('tuple_getitem')
|
||||
|
|
|
@ -24,7 +24,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
|
|||
Fill, GatherNd, GatherV2, InvertPermutation,
|
||||
IsInstance, IsSubClass, ArgMaxWithValue, OnesLike, ZerosLike,
|
||||
Rank, Reshape, ResizeNearestNeighbor, ArgMinWithValue,
|
||||
SameTypeShape, ScatterMax,
|
||||
SameTypeShape, ScatterMax, ScatterUpdate,
|
||||
ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select,
|
||||
Shape, Size, Slice, Split,
|
||||
Squeeze, StridedSlice, Tile,
|
||||
|
@ -193,6 +193,7 @@ __all__ = [
|
|||
'Pad',
|
||||
'MirrorPad',
|
||||
'GatherNd',
|
||||
'ScatterUpdate',
|
||||
'ScatterNdUpdate',
|
||||
'Floor',
|
||||
'NMSWithMask',
|
||||
|
|
|
@ -19,7 +19,7 @@ from ..._c_expression import signature_rw as sig_rw
|
|||
from ..._c_expression import signature_kind as sig_kind
|
||||
from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register
|
||||
from ..._checkparam import Validator as validator, Rel
|
||||
from .._utils import _get_concat_offset
|
||||
from .._utils import get_concat_offset
|
||||
from ...common import dtype as mstype
|
||||
|
||||
|
||||
|
@ -136,7 +136,7 @@ class ConcatOffset(PrimitiveWithInfer):
|
|||
axis = self.axis
|
||||
x_shp = input_x['shape']
|
||||
x_type = input_x['dtype']
|
||||
offset, _, axis = _get_concat_offset(x_shp, x_type, axis, self.name)
|
||||
offset, _, axis = get_concat_offset(x_shp, x_type, axis, self.name)
|
||||
self.add_prim_attr('T', x_type[0].element_type())
|
||||
offset_values = []
|
||||
for i in range(len(x_shp)):
|
||||
|
|
|
@ -24,16 +24,15 @@ import itertools
|
|||
import numbers
|
||||
import numpy as np
|
||||
|
||||
from ..._c_expression import signature_rw as sig_rw
|
||||
from ..._c_expression import signature_kind as sig_kind
|
||||
from ..._checkparam import Validator as validator
|
||||
from ..._checkparam import Rel
|
||||
from ...common import dtype as mstype
|
||||
from ...common.tensor import Tensor
|
||||
from ..operations.math_ops import _infer_shape_reduce
|
||||
from .._utils import _get_concat_offset
|
||||
from .._utils import get_concat_offset
|
||||
from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register
|
||||
|
||||
|
||||
def _check_infer_attr_reduce(axis, keep_dims, prim_name):
|
||||
validator.check_value_type('keep_dims', keep_dims, [bool], prim_name)
|
||||
validator.check_value_type('axis', axis, [int, tuple], prim_name)
|
||||
|
@ -931,7 +930,7 @@ class InvertPermutation(PrimitiveWithInfer):
|
|||
z = [x_value[i] for i in range(len(x_value))]
|
||||
z.sort()
|
||||
|
||||
y = [None]*len(x_value)
|
||||
y = [None] * len(x_value)
|
||||
for i, value in enumerate(x_value):
|
||||
validator.check_value_type("input[%d]" % i, value, [int], self.name)
|
||||
validator.check(f'value', z[i], f'index', i, Rel.EQ, self.name)
|
||||
|
@ -1111,6 +1110,7 @@ class ArgMinWithValue(PrimitiveWithInfer):
|
|||
>>> input_x = Tensor(np.random.rand(5))
|
||||
>>> index, output = P.ArgMinWithValue()(input_x)
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self, axis=0, keep_dims=False):
|
||||
"""init ArgMinWithValue"""
|
||||
|
@ -1352,7 +1352,7 @@ class Concat(PrimitiveWithInfer):
|
|||
axis = self.axis
|
||||
x_shp = input_x['shape']
|
||||
x_type = input_x['dtype']
|
||||
_, all_shp, _ = _get_concat_offset(x_shp, x_type, axis, self.name)
|
||||
_, all_shp, _ = get_concat_offset(x_shp, x_type, axis, self.name)
|
||||
self.add_prim_attr('T', x_type[0].element_type())
|
||||
self.add_prim_attr('inputNums', len(x_shp))
|
||||
ret_shp = x_shp[0].copy()
|
||||
|
@ -1376,15 +1376,13 @@ def _get_pack_shape(x_shape, x_type, axis, prim_name):
|
|||
if axis < 0:
|
||||
axis = axis + rank_base + 1
|
||||
for i in range(1, N):
|
||||
v = x_shape[i]
|
||||
validator.check('len of x_shape[%d]' % i, len(v), 'len of rank_base', rank_base, Rel.EQ, prim_name)
|
||||
validator.check('x_type[%d]' % i, x_type[i], 'base', x_type[0], Rel.EQ, prim_name, TypeError)
|
||||
for j in range(rank_base):
|
||||
if v[j] != x_shape[0][j]:
|
||||
raise ValueError(f"For \'{prim_name}\' element {i} shape in input can not pack with first element")
|
||||
if x_shape[i] != x_shape[0]:
|
||||
raise ValueError(f"For \'{prim_name}\' element {i} shape in input can not pack with first element")
|
||||
out_shape.insert(axis, N)
|
||||
return out_shape
|
||||
|
||||
|
||||
class Pack(PrimitiveWithInfer):
|
||||
r"""
|
||||
Packs a list of tensors in specified axis.
|
||||
|
@ -1831,7 +1829,7 @@ class DiagPart(PrimitiveWithInfer):
|
|||
return x_type
|
||||
|
||||
def infer_shape(self, x_shape):
|
||||
if len(x_shape)%2 != 0 or \
|
||||
if len(x_shape) % 2 != 0 or \
|
||||
not x_shape:
|
||||
raise ValueError(f"For \'{self.name}\' input rank must be non-zero and even, but got rank {len(x_shape)}, "
|
||||
f"with shapes {x_shape}")
|
||||
|
@ -2004,6 +2002,49 @@ class GatherNd(PrimitiveWithInfer):
|
|||
return x_dtype
|
||||
|
||||
|
||||
class ScatterUpdate(PrimitiveWithInfer):
|
||||
"""
|
||||
Update tensor value by using input indices and value.
|
||||
|
||||
Using given values to update tensor value, along with the input indices.
|
||||
|
||||
Args:
|
||||
use_locking (bool): Whether protect the assignment by a lock. Default: True.
|
||||
|
||||
Inputs:
|
||||
- **input_x** (Parameter) - The target tensor, with data type of Parameter.
|
||||
- **indices** (Tensor) - The index of input tensor.
|
||||
- **update** (Tensor) - The tensor to update the input tensor, has the same type as input,
|
||||
and update.shape = indices.shape + input_x.shape[1:].
|
||||
|
||||
Outputs:
|
||||
Tensor, has the same shape and type as `input_x`.
|
||||
|
||||
Examples:
|
||||
>>> input_x = mindspore.Parameter(Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mindspore.float32))
|
||||
>>> indices = Tensor(np.array([[0, 0], [1, 1]]), mindspore.int32)
|
||||
>>> update = Tensor(np.array([1.0, 2.2]), mindspore.float32)
|
||||
>>> op = P.ScatterNdUpdate()
|
||||
>>> output = op(input_x, indices, update)
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self, use_locking=True):
|
||||
"""Init ScatterNdUpdate"""
|
||||
self.init_prim_io_names(inputs=['x', 'indices', 'value'], outputs=['y'])
|
||||
|
||||
def infer_shape(self, x_shape, indices_shape, value_shape):
|
||||
if indices_shape + x_shape[1:] != value_shape:
|
||||
raise ValueError('Input value are not match with input indices.')
|
||||
return x_shape
|
||||
|
||||
def infer_dtype(self, x_dtype, indices_dtype, value_dtype):
|
||||
validator.check_tensor_type_same({'indices': indices_dtype}, mstype.int_type, self.name)
|
||||
args = {"x": x_dtype, "value": value_dtype}
|
||||
validator.check_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name)
|
||||
return x_dtype
|
||||
|
||||
|
||||
class ScatterNdUpdate(PrimitiveWithInfer):
|
||||
"""
|
||||
Update tensor value by using input indices and value.
|
||||
|
@ -2028,11 +2069,6 @@ class ScatterNdUpdate(PrimitiveWithInfer):
|
|||
>>> op = P.ScatterNdUpdate()
|
||||
>>> output = op(input_x, indices, update)
|
||||
"""
|
||||
__mindspore_signature__ = (
|
||||
('input_x', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD),
|
||||
('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD),
|
||||
('value', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD)
|
||||
)
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self, use_locking=True):
|
||||
|
@ -2142,10 +2178,10 @@ class SpaceToDepth(PrimitiveWithInfer):
|
|||
validator.check('x dimension', len(x_shape), '', 4, Rel.EQ)
|
||||
out_shape = copy.deepcopy(x_shape)
|
||||
for i in range(2):
|
||||
if out_shape[i+2] % self.block_size != 0:
|
||||
raise ValueError(f'For \'{self.name}\' input shape[{i+2}] {out_shape[i+2]} should be '
|
||||
if out_shape[i + 2] % self.block_size != 0:
|
||||
raise ValueError(f'For \'{self.name}\' input shape[{i + 2}] {out_shape[i + 2]} should be '
|
||||
f'fully divided by block_size {self.block_size}')
|
||||
out_shape[i+2] //= self.block_size
|
||||
out_shape[i + 2] //= self.block_size
|
||||
|
||||
out_shape[1] *= self.block_size * self.block_size
|
||||
return out_shape
|
||||
|
@ -2199,9 +2235,10 @@ class DepthToSpace(PrimitiveWithInfer):
|
|||
validator.check('x dimension', len(x_shape), '', 4, Rel.EQ)
|
||||
out_shape = copy.deepcopy(x_shape)
|
||||
for i in range(2):
|
||||
out_shape[i+2] *= self.block_size
|
||||
out_shape[i + 2] *= self.block_size
|
||||
|
||||
validator.check_integer('x_shape[1] % (block_size*block_size)', x_shape[1] % (self.block_size*self.block_size),
|
||||
validator.check_integer('x_shape[1] % (block_size*block_size)',
|
||||
x_shape[1] % (self.block_size * self.block_size),
|
||||
0, Rel.EQ, self.name)
|
||||
out_shape[1] //= self.block_size * self.block_size
|
||||
return out_shape
|
||||
|
@ -2251,6 +2288,7 @@ class SpaceToBatch(PrimitiveWithInfer):
|
|||
[[[[1.]]], [[[2.]]], [[[3.]]], [[[4.]]]]
|
||||
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self, block_size, paddings):
|
||||
"""Init SpaceToBatch"""
|
||||
|
@ -2271,12 +2309,12 @@ class SpaceToBatch(PrimitiveWithInfer):
|
|||
validator.check_integer('rank of input_x', len(x_shape), 4, Rel.EQ, self.name)
|
||||
out_shape = copy.deepcopy(x_shape)
|
||||
for i in range(2):
|
||||
padded = out_shape[i+2] + self.paddings[i][0] + \
|
||||
padded = out_shape[i + 2] + self.paddings[i][0] + \
|
||||
self.paddings[i][1]
|
||||
if padded % self.block_size != 0:
|
||||
raise ValueError(f'For \'{self.name}\' padded[{i}] {padded} should be divisible by '
|
||||
f'block_size {self.block_size}')
|
||||
out_shape[i+2] = padded // self.block_size
|
||||
out_shape[i + 2] = padded // self.block_size
|
||||
out_shape[0] *= self.block_size * self.block_size
|
||||
return out_shape
|
||||
|
||||
|
@ -2319,6 +2357,7 @@ class BatchToSpace(PrimitiveWithInfer):
|
|||
[[[[1., 2.], [3., 4.]]]]
|
||||
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self, block_size, crops):
|
||||
"""Init BatchToSpace"""
|
||||
|
@ -2339,10 +2378,10 @@ class BatchToSpace(PrimitiveWithInfer):
|
|||
validator.check('rank of input_x', len(x_shape), '', 4)
|
||||
out_shape = copy.deepcopy(x_shape)
|
||||
for i in range(2):
|
||||
x_block_prod = out_shape[i+2] * self.block_size
|
||||
x_block_prod = out_shape[i + 2] * self.block_size
|
||||
crops_sum = self.crops[i][0] + self.crops[i][1]
|
||||
validator.check("x block shape prod", x_block_prod, 'crops sum', crops_sum, Rel.GT, self.name)
|
||||
out_shape[i+2] = x_block_prod - crops_sum
|
||||
out_shape[i + 2] = x_block_prod - crops_sum
|
||||
block_size_prod = self.block_size * self.block_size
|
||||
if out_shape[0] % block_size_prod != 0:
|
||||
raise ValueError(f'For \'{self.name}\' input_x dimension 0 {out_shape[0]} should be divisible by '
|
||||
|
|
|
@ -24,7 +24,7 @@ from ..._checkparam import Validator as validator
|
|||
from ..._checkparam import Rel
|
||||
from ...common import dtype as mstype
|
||||
from ...common.tensor import Tensor
|
||||
from .._utils import _get_broadcast_shape
|
||||
from .._utils import get_broadcast_shape
|
||||
from ..primitive import PrimitiveWithInfer, prim_attr_register, _run_op
|
||||
|
||||
|
||||
|
@ -75,7 +75,7 @@ class _BinaryOp(PrimitiveWithInfer):
|
|||
self.init_prim_io_names(inputs=['x', 'y'], outputs=['output'])
|
||||
|
||||
def infer_shape(self, x_shape, y_shape):
|
||||
return _get_broadcast_shape(x_shape, y_shape, self.name)
|
||||
return get_broadcast_shape(x_shape, y_shape, self.name)
|
||||
|
||||
|
||||
class _MathBinaryOp(_BinaryOp):
|
||||
|
|
|
@ -27,9 +27,15 @@ class IdentityEC(IExectorComponent):
|
|||
def __call__(self):
|
||||
result_id = self.function[keyword.id] + '-' + self.inputs[keyword.id]
|
||||
group = self.function[keyword.group] + '-' + self.inputs[keyword.group]
|
||||
return {
|
||||
ret = {
|
||||
keyword.id: result_id,
|
||||
keyword.group: group,
|
||||
keyword.desc_inputs: self.inputs[keyword.desc_inputs],
|
||||
keyword.result: self.function[keyword.block](*self.inputs[keyword.desc_inputs])
|
||||
}
|
||||
print("buxue------------------------------------------------")
|
||||
print("inputs")
|
||||
print(ret[keyword.desc_inputs])
|
||||
print("outputs")
|
||||
print(ret[keyword.result])
|
||||
return ret
|
||||
|
|
|
@ -1307,7 +1307,7 @@ raise_set = [
|
|||
('ScatterNdUpdate', {
|
||||
'block': (P.ScatterNdUpdate(), {'exception': TypeError}),
|
||||
'desc_inputs': (Tensor(np.ones((2, 3), np.float32)),
|
||||
Tensor(np.ones((2, 2), np.int32)),
|
||||
Tensor(np.ones((2, 2), np.float32)),
|
||||
Tensor(np.ones((2,), np.float32))),
|
||||
'desc_bprop': [[2, 3]]}),
|
||||
('Pack', {
|
||||
|
|
|
@ -16,13 +16,14 @@
|
|||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from mindspore import Tensor
|
||||
from mindspore import Tensor, Parameter
|
||||
from mindspore import context
|
||||
from mindspore import dtype as mstype
|
||||
from mindspore.nn import Cell
|
||||
from ....mindspore_test_framework.mindspore_test import mindspore_test
|
||||
from ....mindspore_test_framework.pipeline.forward.compile_forward \
|
||||
import pipeline_for_compile_forward_ge_graph_for_case_by_case_config
|
||||
import pipeline_for_compile_forward_ge_graph_for_case_by_case_config, \
|
||||
pipeline_for_compile_forward_ge_graph_for_case_by_case_config_exception
|
||||
|
||||
|
||||
class NetWorkSlicePositive(Cell):
|
||||
|
@ -145,6 +146,160 @@ class TensorAssignWithSlice(Cell):
|
|||
return z
|
||||
|
||||
|
||||
class TensorIndexByOneTensor(Cell):
|
||||
def __init__(self):
|
||||
super(TensorIndexByOneTensor, self).__init__()
|
||||
self.const = Tensor(np.ones((5, 4, 7, 8)), mstype.int32)
|
||||
|
||||
def construct(self, x, index):
|
||||
ret = x[index] + self.const
|
||||
return ret
|
||||
|
||||
|
||||
class TensorIndexByTwoTensors(Cell):
|
||||
def __init__(self):
|
||||
super(TensorIndexByTwoTensors, self).__init__()
|
||||
self.const = Tensor(np.ones((3, 4, 5, 8)), mstype.int32)
|
||||
|
||||
def construct(self, x, index_0, index_1):
|
||||
ret = x[index_0, index_1] + self.const
|
||||
return ret
|
||||
|
||||
|
||||
class TensorIndexByThreeTensors(Cell):
|
||||
def __init__(self):
|
||||
super(TensorIndexByThreeTensors, self).__init__()
|
||||
self.const = Tensor(np.ones((5, 3, 4, 5)), mstype.int32)
|
||||
|
||||
def construct(self, x, index_0, index_1, index_2):
|
||||
ret = x[index_0, index_1, index_2] + self.const
|
||||
return ret
|
||||
|
||||
|
||||
class TensorSetItemByOneTensorWithNumber(Cell):
|
||||
def __init__(self, value):
|
||||
super(TensorSetItemByOneTensorWithNumber, self).__init__()
|
||||
self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
|
||||
self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x")
|
||||
self.value = value
|
||||
|
||||
def construct(self, index):
|
||||
self.param[index] = self.value
|
||||
ret = self.param + self.const
|
||||
return ret
|
||||
|
||||
|
||||
class TensorSetItemByOneTensorWithTensor(Cell):
|
||||
def __init__(self):
|
||||
super(TensorSetItemByOneTensorWithTensor, self).__init__()
|
||||
self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
|
||||
self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x")
|
||||
|
||||
def construct(self, index, value):
|
||||
self.param[index] = value
|
||||
ret = self.param + self.const
|
||||
return ret
|
||||
|
||||
|
||||
class TensorSetItemByOneTensorWithTupleOfNumber(Cell):
|
||||
def __init__(self, value):
|
||||
super(TensorSetItemByOneTensorWithTupleOfNumber, self).__init__()
|
||||
self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
|
||||
self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x")
|
||||
self.value = value
|
||||
|
||||
def construct(self, index):
|
||||
self.param[index] = self.value
|
||||
ret = self.param + self.const
|
||||
return ret
|
||||
|
||||
|
||||
class TensorSetItemByOneTensorWithTupleOfTensor(Cell):
|
||||
def __init__(self):
|
||||
super(TensorSetItemByOneTensorWithTupleOfTensor, self).__init__()
|
||||
self.const = Tensor(np.ones((6, 3, 8)), mstype.float32)
|
||||
self.param = Parameter(Tensor(np.arange(6*3*8).reshape((6, 3, 8)), mstype.float32), name="x")
|
||||
|
||||
def construct(self, index, value_0, value_1, value_2):
|
||||
self.param[index] = (value_0, value_1, value_2)
|
||||
ret = self.param + self.const
|
||||
return ret
|
||||
|
||||
|
||||
class TensorSetItemByTensorsWithNumber(Cell):
|
||||
def __init__(self, value):
|
||||
super(TensorSetItemByTensorsWithNumber, self).__init__()
|
||||
self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
|
||||
self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x")
|
||||
self.value = value
|
||||
|
||||
def construct(self, index_0, index_1, index_2):
|
||||
self.param[index_0, index_1, index_2] = self.value
|
||||
ret = self.param + self.const
|
||||
return ret
|
||||
|
||||
|
||||
class TensorSetItemByTensorsWithTensor(Cell):
|
||||
def __init__(self):
|
||||
super(TensorSetItemByTensorsWithTensor, self).__init__()
|
||||
self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
|
||||
self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x")
|
||||
|
||||
def construct(self, index_0, index_1, index_2, value):
|
||||
self.param[index_0, index_1, index_2] = value
|
||||
ret = self.param + self.const
|
||||
return ret
|
||||
|
||||
|
||||
class TensorSetItemByTensorsWithTensorNumberError(Cell):
|
||||
def __init__(self):
|
||||
super(TensorSetItemByTensorsWithTensorNumberError, self).__init__()
|
||||
self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
|
||||
self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x")
|
||||
|
||||
def construct(self, index_0, index_1, index_2, index_3, value):
|
||||
self.param[index_0, index_1, index_2, index_3] = value
|
||||
ret = self.param + self.const
|
||||
return ret
|
||||
|
||||
|
||||
class TensorSetItemByTensorsWithTupleOfNumber(Cell):
|
||||
def __init__(self, value):
|
||||
super(TensorSetItemByTensorsWithTupleOfNumber, self).__init__()
|
||||
self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
|
||||
self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x")
|
||||
self.value = value
|
||||
|
||||
def construct(self, index_0, index_1, index_2):
|
||||
self.param[index_0, index_1, index_2] = self.value
|
||||
ret = self.param + self.const
|
||||
return ret
|
||||
|
||||
|
||||
class TensorSetItemByTensorsWithTupleOfTensor(Cell):
|
||||
def __init__(self):
|
||||
super(TensorSetItemByTensorsWithTupleOfTensor, self).__init__()
|
||||
self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
|
||||
self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x")
|
||||
|
||||
def construct(self, index_0, index_1, index_2, value_0, value_1, value_2):
|
||||
self.param[index_0, index_1, index_2] = (value_0, value_1, value_2)
|
||||
ret = self.param + self.const
|
||||
return ret
|
||||
|
||||
|
||||
class TensorSetItemByTensorsWithTupleOfTensorNumberError(Cell):
|
||||
def __init__(self):
|
||||
super(TensorSetItemByTensorsWithTupleOfTensorNumberError, self).__init__()
|
||||
self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
|
||||
self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x")
|
||||
|
||||
def construct(self, index_0, index_1, index_2, value_0, value_1):
|
||||
self.param[index_0, index_1, index_2] = (value_0, value_1)
|
||||
ret = self.param + self.const
|
||||
return ret
|
||||
|
||||
|
||||
def test_tensor_assign():
|
||||
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
|
||||
net = TensorAssignWithSlice()
|
||||
|
@ -441,15 +596,206 @@ test_cases = [
|
|||
'block': NetWorkSliceEllipsis(),
|
||||
'desc_inputs': [Tensor(np.ones([6, 7, 8, 9], np.int32))],
|
||||
}),
|
||||
('TensorIndexByOneTensor', {
|
||||
'block': TensorIndexByOneTensor(),
|
||||
'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32),
|
||||
Tensor(np.random.randint(6, size=(5, 4)), mstype.int32)],
|
||||
}),
|
||||
('TensorIndexByTwoTensors', {
|
||||
'block': TensorIndexByTwoTensors(),
|
||||
'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32),
|
||||
Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
|
||||
Tensor(np.random.randint(7, size=(4, 5)), mstype.int32)],
|
||||
}),
|
||||
('TensorIndexByThreeTensors', {
|
||||
'block': TensorIndexByThreeTensors(),
|
||||
'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32),
|
||||
Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
|
||||
Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
|
||||
Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)],
|
||||
}),
|
||||
('TensorSetItemByOneTensorWithNumber', {
|
||||
'block': TensorSetItemByOneTensorWithNumber(value=0.0),
|
||||
'desc_inputs': [Tensor(np.random.randint(4, size=(5, 4)), mstype.int32)],
|
||||
}),
|
||||
('TensorSetItemByOneTensorWithTensor', {
|
||||
'block': TensorSetItemByOneTensorWithTensor(),
|
||||
'desc_inputs': [Tensor(np.random.randint(3, size=(5, 4)), mstype.int32),
|
||||
Tensor(np.zeros((4, 7, 8)), mstype.float32)],
|
||||
}),
|
||||
('TensorSetItemByOneTensorWithTupleOfNumber', {
|
||||
'block': TensorSetItemByOneTensorWithTupleOfNumber(value=(0.0, 1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7)),
|
||||
'desc_inputs': [Tensor(np.random.randint(5, size=(5, 4)), mstype.int32)],
|
||||
}),
|
||||
('TensorSetItemByOneTensorWithTupleOfTensor', {
|
||||
'block': TensorSetItemByOneTensorWithTupleOfTensor(),
|
||||
'desc_inputs': [Tensor(np.random.randint(6, size=(5, 4)), mstype.int32),
|
||||
Tensor(np.zeros((8,), np.float32)),
|
||||
Tensor(np.ones((8,), np.float32)),
|
||||
Tensor(np.ones((8,), np.float32) * 2)],
|
||||
}),
|
||||
('TensorSetItemByTensorsWithNumber', {
|
||||
'block': TensorSetItemByTensorsWithNumber(value=0.0),
|
||||
'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
|
||||
Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
|
||||
Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)],
|
||||
}),
|
||||
('TensorSetItemByTensorsWithTensor', {
|
||||
'block': TensorSetItemByTensorsWithTensor(),
|
||||
'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
|
||||
Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
|
||||
Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32),
|
||||
Tensor(np.zeros((4, 5)), mstype.float32)],
|
||||
}),
|
||||
('TensorSetItemByTensorsWithTupleOfNumber', {
|
||||
'block': TensorSetItemByTensorsWithTupleOfNumber(value=(0.0, 1.1, 2.2, 3.3, 4.4)),
|
||||
'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
|
||||
Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
|
||||
Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)],
|
||||
}),
|
||||
('TensorSetItemByTensorsWithTupleOfTensor', {
|
||||
'block': TensorSetItemByTensorsWithTupleOfTensor(),
|
||||
'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
|
||||
Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
|
||||
Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32),
|
||||
Tensor(np.zeros((4, 5)), mstype.float32),
|
||||
Tensor(np.ones((4, 5)), mstype.float32),
|
||||
Tensor(np.ones((4, 5)) * 2, mstype.float32)],
|
||||
})
|
||||
]
|
||||
|
||||
raise_error_set = [
|
||||
('TensorIndexByOneTensorDtypeError', {
|
||||
'block': (TensorIndexByOneTensor(), {'exception': TypeError}),
|
||||
'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32),
|
||||
Tensor(np.random.randint(6, size=(5, 4)), mstype.int8)],
|
||||
}),
|
||||
('TensorIndexByTwoTensorsShapeError', {
|
||||
'block': (TensorIndexByTwoTensors(), {'exception': ValueError}),
|
||||
'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32),
|
||||
Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
|
||||
Tensor(np.random.randint(7, size=(2, 3, 5)), mstype.int32)],
|
||||
}),
|
||||
('TensorIndexByTwoTensorsDtypeError', {
|
||||
'block': (TensorIndexByTwoTensors(), {'exception': TypeError}),
|
||||
'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32),
|
||||
Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
|
||||
Tensor(np.random.randint(7, size=(4, 5)), mstype.float32)],
|
||||
}),
|
||||
('TensorIndexByThreeTensorsShapeError', {
|
||||
'block': (TensorIndexByThreeTensors(), {'exception': ValueError}),
|
||||
'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32),
|
||||
Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
|
||||
Tensor(np.random.randint(7, size=(3, 4, 5)), mstype.int32),
|
||||
Tensor(np.random.randint(8, size=(5, 2, 4, 5)), mstype.int32)],
|
||||
}),
|
||||
('TensorIndexByThreeTensorsDtypeError', {
|
||||
'block': (TensorIndexByThreeTensors(), {'exception': TypeError}),
|
||||
'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32),
|
||||
Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
|
||||
Tensor(np.random.randint(7, size=(3, 4, 5)), mstype.int64),
|
||||
Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)],
|
||||
}),
|
||||
('TensorSetItemByOneTensorWithNumberTypeError', {
|
||||
'block': (TensorSetItemByOneTensorWithNumber(value=0), {'exception': TypeError}),
|
||||
'desc_inputs': [Tensor(np.random.randint(4, size=(5, 4)), mstype.int32)],
|
||||
}),
|
||||
('TensorSetItemByOneTensorWithTensorShapeError', {
|
||||
'block': (TensorSetItemByOneTensorWithTensor(), {'exception': ValueError}),
|
||||
'desc_inputs': [Tensor(np.random.randint(3, size=(5, 4)), mstype.int32),
|
||||
Tensor(np.zeros((6, 7, 8)), mstype.float32)],
|
||||
}),
|
||||
('TensorSetItemByOneTensorWithTensorDtypeError', {
|
||||
'block': (TensorSetItemByOneTensorWithTensor(), {'exception': TypeError}),
|
||||
'desc_inputs': [Tensor(np.random.randint(3, size=(5, 4)), mstype.int32),
|
||||
Tensor(np.zeros((6, 7, 8)), mstype.int32)],
|
||||
}),
|
||||
('TensorSetItemByOneTensorWithTupleOfNumberTypeError', {
|
||||
'block': (TensorSetItemByOneTensorWithTupleOfNumber(value=(0, 1, 2, 3, 4, 5, 6, 7)), {'exception': TypeError}),
|
||||
'desc_inputs': [Tensor(np.random.randint(5, size=(5, 4)), mstype.int32)],
|
||||
}),
|
||||
('TensorSetItemByOneTensorWithTupleOfNumberNumberError', {
|
||||
'block': (TensorSetItemByOneTensorWithTupleOfNumber(value=(0.0, 1.1, 2.2)), {'exception': ValueError}),
|
||||
'desc_inputs': [Tensor(np.random.randint(5, size=(5, 4)), mstype.int32)],
|
||||
}),
|
||||
('TensorSetItemByOneTensorWithTupleOfTensorDtyeError', {
|
||||
'block': (TensorSetItemByOneTensorWithTupleOfTensor(), {'exception': TypeError}),
|
||||
'desc_inputs': [Tensor(np.random.randint(6, size=(5, 4)), mstype.int32),
|
||||
Tensor(np.zeros((8,), np.int32)),
|
||||
Tensor(np.ones((8,), np.int32)),
|
||||
Tensor(np.ones((8,), np.float32) * 2)],
|
||||
}),
|
||||
('TensorSetItemByTensorsWithNumberTypeError', {
|
||||
'block': (TensorSetItemByTensorsWithNumber(value=0), {'exception': TypeError}),
|
||||
'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
|
||||
Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
|
||||
Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)],
|
||||
}),
|
||||
('TensorSetItemByTensorsWithTensorShapeError', {
|
||||
'block': (TensorSetItemByTensorsWithTensor(), {'exception': ValueError}),
|
||||
'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
|
||||
Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
|
||||
Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32),
|
||||
Tensor(np.zeros((2, 5)), mstype.float32)],
|
||||
}),
|
||||
('TensorSetItemByTensorsWithTensorTypeError', {
|
||||
'block': (TensorSetItemByTensorsWithTensor(), {'exception': TypeError}),
|
||||
'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
|
||||
Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
|
||||
Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32),
|
||||
Tensor(np.zeros((4, 5)), mstype.int32)],
|
||||
}),
|
||||
('TensorSetItemByTensorsWithTensorNumberError', {
|
||||
'block': (TensorSetItemByTensorsWithTensorNumberError(), {'exception': IndexError}),
|
||||
'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
|
||||
Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
|
||||
Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32),
|
||||
Tensor(np.random.randint(8, size=(1, 3, 4, 5)), mstype.int32),
|
||||
Tensor(np.zeros((2, 5)), mstype.float32)],
|
||||
}),
|
||||
('TensorSetItemByTensorsWithTupleOfNumberTypeError', {
|
||||
'block': (TensorSetItemByTensorsWithTupleOfNumber(value=(0, 1, 2, 3, 4)), {'exception': TypeError}),
|
||||
'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
|
||||
Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
|
||||
Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)],
|
||||
}),
|
||||
('TensorSetItemByTensorsWithTupleOfNumberNumberError', {
|
||||
'block': (TensorSetItemByTensorsWithTupleOfNumber(value=(0.0, 1.0, 2.0, 3.0)), {'exception': ValueError}),
|
||||
'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
|
||||
Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
|
||||
Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)],
|
||||
}),
|
||||
('TensorSetItemByTensorsWithTupleOfTensorNumberError', {
|
||||
'block': (TensorSetItemByTensorsWithTupleOfTensorNumberError(), {'exception': ValueError}),
|
||||
'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
|
||||
Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
|
||||
Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32),
|
||||
Tensor(np.zeros((4, 5)), mstype.float32),
|
||||
Tensor(np.ones((4, 5)), mstype.float32)],
|
||||
}),
|
||||
('TensorSetItemByTensorsWithTupleOfTensorTypeError', {
|
||||
'block': (TensorSetItemByTensorsWithTupleOfTensor(), {'exception': TypeError}),
|
||||
'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
|
||||
Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
|
||||
Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32),
|
||||
Tensor(np.zeros((4, 5)), mstype.float32),
|
||||
Tensor(np.ones((4, 5)), mstype.int32),
|
||||
Tensor(np.ones((4, 5)) * 2, mstype.int32)],
|
||||
})
|
||||
]
|
||||
|
||||
|
||||
@mindspore_test(pipeline_for_compile_forward_ge_graph_for_case_by_case_config)
|
||||
def test_compile():
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
def test_exec():
|
||||
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
|
||||
return test_cases
|
||||
|
||||
|
||||
@mindspore_test(pipeline_for_compile_forward_ge_graph_for_case_by_case_config_exception)
|
||||
def test_check_exception():
|
||||
return raise_error_set
|
||||
|
||||
|
||||
def test_tensor_slice_reduce_out_of_bounds_neg():
|
||||
class NetWork(Cell):
|
||||
def __init__(self):
|
||||
|
|
|
@ -26,7 +26,7 @@ from mindspore.ops import functional as F
|
|||
from mindspore.ops import operations as P
|
||||
from mindspore.ops._grad.grad_base import bprop_getters
|
||||
from mindspore.ops._grad.grad_math_ops import binop_grad_common
|
||||
from mindspore.ops._utils import _get_broadcast_shape
|
||||
from mindspore.ops._utils import get_broadcast_shape
|
||||
from mindspore.ops.primitive import PrimitiveWithInfer, prim_attr_register
|
||||
from mindspore.train.loss_scale_manager import DynamicLossScaleManager
|
||||
|
||||
|
@ -54,7 +54,7 @@ class MockSub(PrimitiveWithInfer):
|
|||
self.init_prim_io_names(inputs=['x', 'y'], outputs=['output'])
|
||||
|
||||
def infer_shape(self, x_shape, y_shape):
|
||||
return _get_broadcast_shape(x_shape, y_shape)
|
||||
return get_broadcast_shape(x_shape, y_shape)
|
||||
|
||||
def infer_dtype(self, x_dtype, y_dtype):
|
||||
return x_dtype
|
||||
|
|
Loading…
Reference in New Issue