forked from mindspore-Ecosystem/mindspore
support tensor get value by tensor index
support tensor set value by tensor index
This commit is contained in:
parent
ca74e624e2
commit
e490618db8
|
@ -1172,6 +1172,12 @@ int GenerateStridedSliceParametersFromNumber(const AbstractScalarPtr &scalar, co
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
FuncGraphPtr ExpandADim(const FuncGraphPtr &ret_graph, const AnfNodePtr &tensor_node) {
|
||||||
|
auto PrimExpandDims = GetPythonOps("expand_dims", "mindspore.ops.functional");
|
||||||
|
ret_graph->set_output(NewCNode({NewValueNode(PrimExpandDims), tensor_node, NewValueNode(0)}, ret_graph));
|
||||||
|
return ret_graph;
|
||||||
|
}
|
||||||
|
|
||||||
FuncGraphPtr TensorSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
|
FuncGraphPtr TensorSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
|
||||||
// slice a tensor
|
// slice a tensor
|
||||||
// args: tensor, slice or slice tuple
|
// args: tensor, slice or slice tuple
|
||||||
|
@ -1229,12 +1235,6 @@ FuncGraphPtr TensorSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec
|
||||||
return ret_graph;
|
return ret_graph;
|
||||||
}
|
}
|
||||||
|
|
||||||
FuncGraphPtr TensorSlice::ExpandADim(const FuncGraphPtr &ret_graph, const AnfNodePtr &tensor_node) const {
|
|
||||||
auto PrimExpandDims = GetPythonOps("expand_dims", "mindspore.ops.functional");
|
|
||||||
ret_graph->set_output(NewCNode({NewValueNode(PrimExpandDims), tensor_node, NewValueNode(0)}, ret_graph));
|
|
||||||
return ret_graph;
|
|
||||||
}
|
|
||||||
|
|
||||||
FuncGraphPtr TupleGetItemTensor::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
|
FuncGraphPtr TupleGetItemTensor::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
|
||||||
// select indexed item
|
// select indexed item
|
||||||
// args: tuple of items, index
|
// args: tuple of items, index
|
||||||
|
|
|
@ -206,8 +206,6 @@ class TensorSlice : public MetaFuncGraph {
|
||||||
MS_DECLARE_PARENT(TensorSlice, MetaFuncGraph)
|
MS_DECLARE_PARENT(TensorSlice, MetaFuncGraph)
|
||||||
FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
|
FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
|
||||||
friend bool operator==(const TensorSlice &lhs, const TensorSlice &rhs) { return lhs.name_ == rhs.name_; }
|
friend bool operator==(const TensorSlice &lhs, const TensorSlice &rhs) { return lhs.name_ == rhs.name_; }
|
||||||
|
|
||||||
FuncGraphPtr ExpandADim(const FuncGraphPtr &ret_graph, const AnfNodePtr &tensor_node) const;
|
|
||||||
};
|
};
|
||||||
using TensorSlicePtr = std::shared_ptr<TensorSlice>;
|
using TensorSlicePtr = std::shared_ptr<TensorSlice>;
|
||||||
|
|
||||||
|
|
|
@ -101,6 +101,7 @@ const char kNameReLU6[] = "ReLU6";
|
||||||
const char kNameReLU6Grad[] = "ReLU6Grad";
|
const char kNameReLU6Grad[] = "ReLU6Grad";
|
||||||
const char kNameElu[] = "Elu";
|
const char kNameElu[] = "Elu";
|
||||||
const char kNameEluGrad[] = "EluGrad";
|
const char kNameEluGrad[] = "EluGrad";
|
||||||
|
const char kNameScatterUpdate[] = "ScatterUpdate";
|
||||||
const char kNameScatterNdUpdate[] = "ScatterNdUpdate";
|
const char kNameScatterNdUpdate[] = "ScatterNdUpdate";
|
||||||
const char kNameScatterMax[] = "ScatterMax";
|
const char kNameScatterMax[] = "ScatterMax";
|
||||||
const char kNameNMSWithMask[] = "NMSWithMask";
|
const char kNameNMSWithMask[] = "NMSWithMask";
|
||||||
|
@ -256,6 +257,7 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma
|
||||||
{string(kNameResizeBilinear), ADPT_DESC(ResizeBilinearV2D)},
|
{string(kNameResizeBilinear), ADPT_DESC(ResizeBilinearV2D)},
|
||||||
{string(kNameZerosLike), ADPT_DESC(ZerosLike)},
|
{string(kNameZerosLike), ADPT_DESC(ZerosLike)},
|
||||||
{string(kNameOnesLike), ADPT_DESC(OnesLike)},
|
{string(kNameOnesLike), ADPT_DESC(OnesLike)},
|
||||||
|
{string(kNameScatterUpdate), ADPT_DESC(ScatterUpdate)},
|
||||||
{string(kNameScatterNdUpdate), ADPT_DESC(ScatterNdUpdate)},
|
{string(kNameScatterNdUpdate), ADPT_DESC(ScatterNdUpdate)},
|
||||||
{string(kNameScatterMax), ADPT_DESC(ScatterMax)},
|
{string(kNameScatterMax), ADPT_DESC(ScatterMax)},
|
||||||
{string(kNameNMSWithMask), ADPT_DESC(NMSWithMask)},
|
{string(kNameNMSWithMask), ADPT_DESC(NMSWithMask)},
|
||||||
|
|
|
@ -515,6 +515,11 @@ INPUT_MAP(Unpack) = {{1, INPUT_DESC(x)}};
|
||||||
ATTR_MAP(Unpack) = {{"axis", ATTR_DESC(axis, AnyTraits<int>())}, {"num", ATTR_DESC(num, AnyTraits<int>())}};
|
ATTR_MAP(Unpack) = {{"axis", ATTR_DESC(axis, AnyTraits<int>())}, {"num", ATTR_DESC(num, AnyTraits<int>())}};
|
||||||
DYN_OUTPUT_MAP(Unpack) = {{0, DYN_OUTPUT_DESC(y)}};
|
DYN_OUTPUT_MAP(Unpack) = {{0, DYN_OUTPUT_DESC(y)}};
|
||||||
|
|
||||||
|
// ScatterUpdate
|
||||||
|
INPUT_MAP(ScatterUpdate) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(indices)}, {3, INPUT_DESC(updates)}};
|
||||||
|
ATTR_MAP(ScatterUpdate) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
|
||||||
|
OUTPUT_MAP(ScatterUpdate) = {{0, OUTPUT_DESC(var)}};
|
||||||
|
|
||||||
// ScatterNdUpdate
|
// ScatterNdUpdate
|
||||||
INPUT_MAP(ScatterNdUpdate) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(indices)}, {3, INPUT_DESC(updates)}};
|
INPUT_MAP(ScatterNdUpdate) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(indices)}, {3, INPUT_DESC(updates)}};
|
||||||
ATTR_MAP(ScatterNdUpdate) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
|
ATTR_MAP(ScatterNdUpdate) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
|
||||||
|
|
|
@ -132,6 +132,8 @@ DECLARE_OP_ADAPTER(ZerosLike)
|
||||||
DECLARE_OP_USE_OUTPUT(ZerosLike)
|
DECLARE_OP_USE_OUTPUT(ZerosLike)
|
||||||
DECLARE_OP_ADAPTER(OnesLike)
|
DECLARE_OP_ADAPTER(OnesLike)
|
||||||
DECLARE_OP_USE_OUTPUT(OnesLike)
|
DECLARE_OP_USE_OUTPUT(OnesLike)
|
||||||
|
DECLARE_OP_ADAPTER(ScatterUpdate)
|
||||||
|
DECLARE_OP_USE_OUTPUT(ScatterUpdate)
|
||||||
DECLARE_OP_ADAPTER(ScatterNdUpdate)
|
DECLARE_OP_ADAPTER(ScatterNdUpdate)
|
||||||
DECLARE_OP_USE_OUTPUT(ScatterNdUpdate)
|
DECLARE_OP_USE_OUTPUT(ScatterNdUpdate)
|
||||||
DECLARE_OP_ADAPTER(ScatterMax)
|
DECLARE_OP_ADAPTER(ScatterMax)
|
||||||
|
|
|
@ -178,13 +178,14 @@ from .bounding_box_encode import _bounding_box_encode_tbe
|
||||||
from .check_valid import _check_valid_tbe
|
from .check_valid import _check_valid_tbe
|
||||||
from .iou import _iou_tbe
|
from .iou import _iou_tbe
|
||||||
from .arg_max import _arg_max_tbe
|
from .arg_max import _arg_max_tbe
|
||||||
from .nms_with_mask import nms_with_mask_op_info
|
from .nms_with_mask import _nms_with_mask_tbe
|
||||||
from .random_choice_with_mask import random_choice_with_mask_op_info
|
from .random_choice_with_mask import _random_choice_with_mask_tbe
|
||||||
from .sgd import sgd_op_info
|
from .sgd import _sgd_tbe
|
||||||
from .lars_update import lars_update_op_info
|
from .lars_update import _lars_update_tbe
|
||||||
from .bn_training_update_v2 import _bn_training_update_v2_tbe
|
from .bn_training_update_v2 import _bn_training_update_v2_tbe
|
||||||
from .square_sum_all import square_sum_all_op_info
|
from .square_sum_all import _square_sum_all_tbe
|
||||||
from .pack import _pack_tbe
|
from .pack import _pack_tbe
|
||||||
from .unpack import _unpack_tbe
|
from .unpack import _unpack_tbe
|
||||||
|
from .scatter_update import _scatter_update_tbe
|
||||||
from .prelu import _prelu_tbe
|
from .prelu import _prelu_tbe
|
||||||
from .prelu_grad import _prelu_grad_tbe
|
from .prelu_grad import _prelu_grad_tbe
|
||||||
|
|
|
@ -0,0 +1,42 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""ScatterUpdate op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||||
|
|
||||||
|
scatter_update_op_info = TBERegOp("ScatterUpdate") \
|
||||||
|
.fusion_type("ELEMWISE") \
|
||||||
|
.async_flag(False) \
|
||||||
|
.binfile_name("scatter_update.so") \
|
||||||
|
.compute_cost(10) \
|
||||||
|
.kernel_name("scatter_update") \
|
||||||
|
.partial_flag(True) \
|
||||||
|
.attr("use_locking", "optional", "bool", "all") \
|
||||||
|
.input(0, "var", False, "required", "all") \
|
||||||
|
.input(1, "indices", False, "required", "all") \
|
||||||
|
.input(1, "updates", False, "required", "all") \
|
||||||
|
.output(0, "var", False, "required", "all") \
|
||||||
|
.dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||||
|
.dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I8_Default, DataType.I8_Default) \
|
||||||
|
.dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default, DataType.U8_Default,) \
|
||||||
|
.dtype_format(DataType.BOOL_Default, DataType.I32_Default, DataType.BOOL_Default, DataType.BOOL_Default) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
@op_info_register(scatter_update_op_info)
|
||||||
|
def _scatter_update_tbe():
|
||||||
|
"""ScatterUpdate TBE register"""
|
||||||
|
return
|
|
@ -14,6 +14,6 @@
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
"""ops utils."""
|
"""ops utils."""
|
||||||
from .utils import _get_broadcast_shape, _get_concat_offset
|
from .utils import get_broadcast_shape, get_concat_offset
|
||||||
|
|
||||||
__all__ = ['_get_broadcast_shape', '_get_concat_offset']
|
__all__ = ['get_broadcast_shape', 'get_concat_offset']
|
||||||
|
|
|
@ -19,7 +19,8 @@ from ..._checkparam import Validator as validator
|
||||||
from ..._checkparam import Rel
|
from ..._checkparam import Rel
|
||||||
from ...common import dtype as mstype
|
from ...common import dtype as mstype
|
||||||
|
|
||||||
def _get_broadcast_shape(x_shape, y_shape, prim_name):
|
|
||||||
|
def get_broadcast_shape(x_shape, y_shape, prim_name):
|
||||||
"""
|
"""
|
||||||
Doing broadcast between tensor x and tensor y.
|
Doing broadcast between tensor x and tensor y.
|
||||||
|
|
||||||
|
@ -37,7 +38,7 @@ def _get_broadcast_shape(x_shape, y_shape, prim_name):
|
||||||
Examples:
|
Examples:
|
||||||
>>> x_shape = [1, 2, 3]
|
>>> x_shape = [1, 2, 3]
|
||||||
>>> y_shape = [1, 2]
|
>>> y_shape = [1, 2]
|
||||||
>>> broadcast_shape = _get_broadcast_shape(x_shape, y_shape)
|
>>> broadcast_shape = get_broadcast_shape(x_shape, y_shape)
|
||||||
"""
|
"""
|
||||||
if x_shape == y_shape:
|
if x_shape == y_shape:
|
||||||
return x_shape
|
return x_shape
|
||||||
|
@ -54,15 +55,14 @@ def _get_broadcast_shape(x_shape, y_shape, prim_name):
|
||||||
elif x_shape[i] == y_shape[i]:
|
elif x_shape[i] == y_shape[i]:
|
||||||
broadcast_shape_back.append(x_shape[i])
|
broadcast_shape_back.append(x_shape[i])
|
||||||
else:
|
else:
|
||||||
raise ValueError("For '{}' the x_shape {} and y_shape {} can not broadcast.".format(
|
raise ValueError(f"For '{prim_name}', the x_shape {x_shape} and y_shape {y_shape} can not broadcast.")
|
||||||
prim_name, x_shape, y_shape))
|
|
||||||
|
|
||||||
broadcast_shape_front = y_shape[0: y_len - length] if length == x_len else x_shape[0: x_len - length]
|
broadcast_shape_front = y_shape[0: y_len - length] if length == x_len else x_shape[0: x_len - length]
|
||||||
broadcast_shape = broadcast_shape_front + broadcast_shape_back
|
broadcast_shape = list(broadcast_shape_front) + broadcast_shape_back
|
||||||
return broadcast_shape
|
return broadcast_shape
|
||||||
|
|
||||||
|
|
||||||
def _get_concat_offset(x_shp, x_type, axis, prim_name):
|
def get_concat_offset(x_shp, x_type, axis, prim_name):
|
||||||
"""for concat and concatoffset check args and compute offset"""
|
"""for concat and concatoffset check args and compute offset"""
|
||||||
validator.check_value_type("shape", x_shp, [tuple], prim_name)
|
validator.check_value_type("shape", x_shp, [tuple], prim_name)
|
||||||
validator.check_integer("input_x rank", len(x_shp), 0, Rel.GT, prim_name)
|
validator.check_integer("input_x rank", len(x_shp), 0, Rel.GT, prim_name)
|
||||||
|
@ -73,7 +73,7 @@ def _get_concat_offset(x_shp, x_type, axis, prim_name):
|
||||||
if axis < 0:
|
if axis < 0:
|
||||||
axis = axis + rank_base
|
axis = axis + rank_base
|
||||||
all_shp = x_shp[0][axis]
|
all_shp = x_shp[0][axis]
|
||||||
offset = [0,]
|
offset = [0]
|
||||||
for i in range(1, len(x_shp)):
|
for i in range(1, len(x_shp)):
|
||||||
v = x_shp[i]
|
v = x_shp[i]
|
||||||
validator.check('len of x_shp[%d]' % i, len(v), 'len of x_shp[0]', len(x_shp[0]), Rel.EQ, prim_name)
|
validator.check('len of x_shp[%d]' % i, len(v), 'len of x_shp[0]', len(x_shp[0]), Rel.EQ, prim_name)
|
||||||
|
|
|
@ -1,226 +0,0 @@
|
||||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ============================================================================
|
|
||||||
|
|
||||||
"""constexpr util"""
|
|
||||||
|
|
||||||
from functools import reduce
|
|
||||||
import numpy as np
|
|
||||||
from ...primitive import constexpr
|
|
||||||
from ....common.tensor import Tensor
|
|
||||||
from ....common import dtype as mstype
|
|
||||||
from ...._extends.utils import Slice, Ellipsis_
|
|
||||||
|
|
||||||
@constexpr
|
|
||||||
def check_equal(param1, param2, msg="{},{}"):
|
|
||||||
"""Checks whether the two parameters are equal or not."""
|
|
||||||
if param1 != param2:
|
|
||||||
raise ValueError(msg.format(param1, param2))
|
|
||||||
return param1
|
|
||||||
|
|
||||||
|
|
||||||
@constexpr
|
|
||||||
def check_ellipsis_shape_size(data_shape, value_shape, data_size, value_size):
|
|
||||||
"""Checks the shape and size of the sensor and value."""
|
|
||||||
if data_shape == value_shape or data_size == value_size or value_size == 1:
|
|
||||||
return True
|
|
||||||
raise ValueError("The value(shape={}), can not assign to tensor(shape={}).".format(value_shape, data_shape))
|
|
||||||
|
|
||||||
|
|
||||||
@constexpr
|
|
||||||
def check_tensor_setitem_index(index, element_type=None):
|
|
||||||
"""Checks tuple index type of tensor assignment."""
|
|
||||||
if index is None:
|
|
||||||
raise IndexError("Tensor's index cannot be None.")
|
|
||||||
# eg. Tensor[Slice] = u
|
|
||||||
if isinstance(index, Slice):
|
|
||||||
return True
|
|
||||||
# eg. Tensor[tuple] = u
|
|
||||||
if isinstance(index, tuple):
|
|
||||||
if not index:
|
|
||||||
raise IndexError("Tensor's index cannot be empty.")
|
|
||||||
# eg. Tensor[tuple(Slice...)] = u
|
|
||||||
if isinstance(index[0], (Slice, Ellipsis_, int)):
|
|
||||||
return True
|
|
||||||
raise IndexError("Index of type '{}' is not supported yet.".format(type(index[0])))
|
|
||||||
# eg. Tensor[Tensor[dtype=bool]] = u
|
|
||||||
if index == mstype.tensor:
|
|
||||||
if element_type is None or element_type != mstype.bool_:
|
|
||||||
raise TypeError(
|
|
||||||
"The index of tensor should be a bool type tensor. "
|
|
||||||
"{} type is not supported yet.".format(element_type))
|
|
||||||
return True
|
|
||||||
|
|
||||||
raise IndexError("Index of type '{}' is not supported yet.".format(type(index)))
|
|
||||||
|
|
||||||
|
|
||||||
@constexpr
|
|
||||||
def is_same_type(inst, type_):
|
|
||||||
"""
|
|
||||||
Checks whether an object is an instance of a target type.
|
|
||||||
|
|
||||||
Inputs:
|
|
||||||
inst (mindspore.dtype): Inspected type.
|
|
||||||
type_ (mindspore.dtype): Target type.
|
|
||||||
|
|
||||||
Outputs:
|
|
||||||
bool, the check result.
|
|
||||||
"""
|
|
||||||
return inst == type_
|
|
||||||
|
|
||||||
|
|
||||||
def slice_expand(input_slices, shape):
|
|
||||||
"""
|
|
||||||
Converts slice to indices.
|
|
||||||
|
|
||||||
Inputs:
|
|
||||||
slices (Union[Slice, tuple[Slice]]): Slice tuple or slice.
|
|
||||||
shape (tuple): The shape of a sensor is an integer element tuple.
|
|
||||||
|
|
||||||
Outputs:
|
|
||||||
tuple[list], This is expressed as (begins, ends, strides).
|
|
||||||
"""
|
|
||||||
begin = []
|
|
||||||
end = []
|
|
||||||
strides = []
|
|
||||||
index = 0
|
|
||||||
slices = None
|
|
||||||
# Slice or tuple(Slice...)
|
|
||||||
if isinstance(input_slices, Slice):
|
|
||||||
slices = (input_slices,)
|
|
||||||
elif isinstance(input_slices, (tuple, list)) and input_slices and isinstance(input_slices[0], (Slice, Ellipsis_)):
|
|
||||||
is_have_ellipsis = False
|
|
||||||
for _, element in enumerate(input_slices):
|
|
||||||
if isinstance(element, Ellipsis_):
|
|
||||||
is_have_ellipsis = True
|
|
||||||
break
|
|
||||||
if is_have_ellipsis:
|
|
||||||
slices = ellipsis2slice(input_slices, shape)
|
|
||||||
else:
|
|
||||||
slices = input_slices
|
|
||||||
else:
|
|
||||||
raise IndexError("Tensor's index type is not supported yet.")
|
|
||||||
|
|
||||||
for s in slices:
|
|
||||||
start = 0 if (s.start is None) else s.start
|
|
||||||
stop = shape[index] if (s.end is None) else s.end
|
|
||||||
step = 1 if (s.step is None) else s.step
|
|
||||||
begin.append(start)
|
|
||||||
end.append(stop)
|
|
||||||
strides.append(step)
|
|
||||||
index += 1
|
|
||||||
while index < len(shape):
|
|
||||||
begin.append(0)
|
|
||||||
end.append(shape[index])
|
|
||||||
strides.append(1)
|
|
||||||
index += 1
|
|
||||||
return begin, end, strides
|
|
||||||
|
|
||||||
|
|
||||||
def ellipsis2slice(input_, shape):
|
|
||||||
"""Converts ellipsis to slice."""
|
|
||||||
input_slice = input_
|
|
||||||
result = []
|
|
||||||
if isinstance(input_, Ellipsis_):
|
|
||||||
input_slice = (input_,)
|
|
||||||
ell_count = 0
|
|
||||||
for _, element in enumerate(input_slice):
|
|
||||||
if not isinstance(element, Ellipsis_):
|
|
||||||
result.append(element)
|
|
||||||
continue
|
|
||||||
ell_count += 1
|
|
||||||
if ell_count > 1:
|
|
||||||
raise IndexError("There cannot be more than one ellisis (...) in the index of the tensor, "
|
|
||||||
"but it is currently {}".format(input_slice))
|
|
||||||
for _ in range(len(shape) - len(input_slice) + 1):
|
|
||||||
result.append(Slice(None, None, None))
|
|
||||||
return tuple(result)
|
|
||||||
|
|
||||||
|
|
||||||
@constexpr
|
|
||||||
def slice2indices(input_slices, shape):
|
|
||||||
"""
|
|
||||||
Converts slice to indices.
|
|
||||||
|
|
||||||
Inputs:
|
|
||||||
slices (Union[Slice, tuple[Slice]]): Slice tuple or slice.
|
|
||||||
shape (tuple): The shape of a tensor is an integer element tuple.
|
|
||||||
|
|
||||||
Outputs:
|
|
||||||
Tensor, the shape is (n, 1).
|
|
||||||
"""
|
|
||||||
begin, end, strides = slice_expand(input_slices, shape)
|
|
||||||
np_r = []
|
|
||||||
for i, element in enumerate(shape):
|
|
||||||
s = begin[i] if (begin[i] >= 0) else (element + begin[i])
|
|
||||||
e = end[i] if (end[i] >= 0) else (element + end[i])
|
|
||||||
np_r.append(np.r_[s:e:strides[i]])
|
|
||||||
# Reference: np.ravel_multi_index((np.ix_(np.r_[1:3:1], np.r_[0:4:1], np.r_[4:0:-1])), a.shape)
|
|
||||||
np_ix = np.ix_(*np_r)
|
|
||||||
ravel = np.ravel_multi_index(np_ix, shape)
|
|
||||||
ravel = Tensor(ravel.reshape(-1, 1), dtype=mstype.int32)
|
|
||||||
return ravel
|
|
||||||
|
|
||||||
@constexpr
|
|
||||||
def check_indices(indices_size, index):
|
|
||||||
"""Checks indices whether is empty."""
|
|
||||||
if indices_size < 1:
|
|
||||||
raise IndexError("The tensor's index is unreasonable. index:{}".format(index))
|
|
||||||
return indices_size
|
|
||||||
|
|
||||||
|
|
||||||
@constexpr
|
|
||||||
def check_indices_value_size(indices_size, value_size):
|
|
||||||
"""Checks if the sizes are already matched."""
|
|
||||||
if value_size < 1:
|
|
||||||
raise ValueError("The value assigned to tensor cannot be empty.")
|
|
||||||
if value_size > 1:
|
|
||||||
if value_size != indices_size:
|
|
||||||
raise ValueError(
|
|
||||||
"The value given to tensor does not match the index size,"
|
|
||||||
" value size:{}, indics size:{}".format(value_size, indices_size))
|
|
||||||
return value_size
|
|
||||||
|
|
||||||
@constexpr
|
|
||||||
def integer_to_indices(index, shape):
|
|
||||||
"""Converts int or tuple[int] to indices."""
|
|
||||||
size = reduce(lambda x, y: x * y, shape)
|
|
||||||
range_ = np.arange(size).reshape(shape)
|
|
||||||
value = range_[index]
|
|
||||||
value = value.reshape(-1, 1)
|
|
||||||
return Tensor(value, dtype=mstype.int32)
|
|
||||||
|
|
||||||
@constexpr
|
|
||||||
def tuple_element_is_slice(indexs):
|
|
||||||
"""Judges tuple element type."""
|
|
||||||
if not indexs:
|
|
||||||
raise IndexError("Tensor's index cannot be empty.")
|
|
||||||
if isinstance(indexs, tuple):
|
|
||||||
for _, ele in enumerate(indexs):
|
|
||||||
if not isinstance(ele, Slice):
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
@constexpr
|
|
||||||
def tuple_element_is_int(indexs):
|
|
||||||
"""Judges tuple element type."""
|
|
||||||
if not indexs:
|
|
||||||
raise IndexError("Tensor's index cannot be empty.")
|
|
||||||
if isinstance(indexs, tuple):
|
|
||||||
for _, ele in enumerate(indexs):
|
|
||||||
if not isinstance(ele, int):
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
return False
|
|
|
@ -0,0 +1,487 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""constexpr util"""
|
||||||
|
from functools import reduce
|
||||||
|
import numpy as np
|
||||||
|
from ...primitive import constexpr
|
||||||
|
from ....common.tensor import Tensor
|
||||||
|
from ....common import dtype as mstype
|
||||||
|
from ...._extends.utils import Slice, Ellipsis_
|
||||||
|
from ....ops import _utils as op_utils
|
||||||
|
from ...composite import base
|
||||||
|
from .... import log as logger
|
||||||
|
from ... import functional as F
|
||||||
|
from ... import operations as P
|
||||||
|
|
||||||
|
hyper_map = base.HyperMap()
|
||||||
|
pack = P.Pack(axis=-1)
|
||||||
|
|
||||||
|
ALL_TENSOR = 0
|
||||||
|
NO_TENSOR = 1
|
||||||
|
CONTAIN_TENSOR = 2
|
||||||
|
ALL_SCALAR = 3
|
||||||
|
|
||||||
|
INT_ = 0
|
||||||
|
BOOL_ = 1
|
||||||
|
UNSUPPORTED_DTYPE = 2
|
||||||
|
|
||||||
|
TENSOR_SETITEM = "tensor setitem"
|
||||||
|
TENSOR_GETITEM = "tensor getitem"
|
||||||
|
|
||||||
|
SET_ITEM_BY_ONE_TENSOR = 0
|
||||||
|
SET_ITEM_BY_TUPLE_OF_TENSOR = 1
|
||||||
|
|
||||||
|
|
||||||
|
@constexpr
|
||||||
|
def check_equal(param1, param2, msg="{},{}"):
|
||||||
|
"""Checks whether the two parameters are equal or not."""
|
||||||
|
if param1 != param2:
|
||||||
|
raise ValueError(msg.format(param1, param2))
|
||||||
|
return param1
|
||||||
|
|
||||||
|
|
||||||
|
@constexpr
|
||||||
|
def check_ellipsis_shape_size(data_shape, value_shape, data_size, value_size):
|
||||||
|
"""Checks the shape and size of the sensor and value."""
|
||||||
|
if data_shape == value_shape or data_size == value_size or value_size == 1:
|
||||||
|
return True
|
||||||
|
raise ValueError("The value(shape={}), can not assign to tensor(shape={}).".format(value_shape, data_shape))
|
||||||
|
|
||||||
|
|
||||||
|
@constexpr
|
||||||
|
def check_tensor_setitem_index(index, element_type=None):
|
||||||
|
"""Checks tuple index type of tensor assignment."""
|
||||||
|
if index is None:
|
||||||
|
raise IndexError("Tensor's index cannot be None.")
|
||||||
|
# eg. Tensor[Slice] = u
|
||||||
|
if isinstance(index, Slice):
|
||||||
|
return True
|
||||||
|
# eg. Tensor[tuple] = u
|
||||||
|
if isinstance(index, tuple):
|
||||||
|
if not index:
|
||||||
|
raise IndexError("Tensor's index cannot be empty.")
|
||||||
|
# eg. Tensor[tuple(Slice...)] = u
|
||||||
|
if isinstance(index[0], (Slice, Ellipsis_, int)):
|
||||||
|
return True
|
||||||
|
raise IndexError("Index of type '{}' is not supported yet.".format(type(index[0])))
|
||||||
|
# eg. Tensor[Tensor[dtype=bool]] = u
|
||||||
|
if isinstance(index, mstype.tensor_type):
|
||||||
|
if element_type is None or element_type != mstype.bool_:
|
||||||
|
raise TypeError(
|
||||||
|
"The index of tensor should be a bool type tensor. "
|
||||||
|
"{} type is not supported yet.".format(element_type))
|
||||||
|
return True
|
||||||
|
|
||||||
|
raise IndexError("Index of type '{}' is not supported yet.".format(type(index)))
|
||||||
|
|
||||||
|
|
||||||
|
@constexpr
|
||||||
|
def is_same_type(inst, type_):
|
||||||
|
"""
|
||||||
|
Checks whether an object is an instance of a target type.
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
inst (mindspore.dtype): Inspected type.
|
||||||
|
type_ (mindspore.dtype): Target type.
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
bool, the check result.
|
||||||
|
"""
|
||||||
|
return inst == type_
|
||||||
|
|
||||||
|
|
||||||
|
def slice_expand(input_slices, shape):
|
||||||
|
"""
|
||||||
|
Converts slice to indices.
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
slices (Union[Slice, tuple[Slice]]): Slice tuple or slice.
|
||||||
|
shape (tuple): The shape of a sensor is an integer element tuple.
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
tuple[list], This is expressed as (begins, ends, strides).
|
||||||
|
"""
|
||||||
|
begin = []
|
||||||
|
end = []
|
||||||
|
strides = []
|
||||||
|
index = 0
|
||||||
|
slices = None
|
||||||
|
# Slice or tuple(Slice...)
|
||||||
|
if isinstance(input_slices, Slice):
|
||||||
|
slices = (input_slices,)
|
||||||
|
elif isinstance(input_slices, (tuple, list)) and input_slices and isinstance(input_slices[0], (Slice, Ellipsis_)):
|
||||||
|
is_have_ellipsis = False
|
||||||
|
for _, element in enumerate(input_slices):
|
||||||
|
if isinstance(element, Ellipsis_):
|
||||||
|
is_have_ellipsis = True
|
||||||
|
break
|
||||||
|
if is_have_ellipsis:
|
||||||
|
slices = ellipsis2slice(input_slices, shape)
|
||||||
|
else:
|
||||||
|
slices = input_slices
|
||||||
|
else:
|
||||||
|
raise IndexError("Tensor's index type is not supported yet.")
|
||||||
|
|
||||||
|
for s in slices:
|
||||||
|
start = 0 if (s.start is None) else s.start
|
||||||
|
stop = shape[index] if (s.end is None) else s.end
|
||||||
|
step = 1 if (s.step is None) else s.step
|
||||||
|
begin.append(start)
|
||||||
|
end.append(stop)
|
||||||
|
strides.append(step)
|
||||||
|
index += 1
|
||||||
|
while index < len(shape):
|
||||||
|
begin.append(0)
|
||||||
|
end.append(shape[index])
|
||||||
|
strides.append(1)
|
||||||
|
index += 1
|
||||||
|
return begin, end, strides
|
||||||
|
|
||||||
|
|
||||||
|
def ellipsis2slice(input_, shape):
|
||||||
|
"""Converts ellipsis to slice."""
|
||||||
|
input_slice = input_
|
||||||
|
result = []
|
||||||
|
if isinstance(input_, Ellipsis_):
|
||||||
|
input_slice = (input_,)
|
||||||
|
ell_count = 0
|
||||||
|
for _, element in enumerate(input_slice):
|
||||||
|
if not isinstance(element, Ellipsis_):
|
||||||
|
result.append(element)
|
||||||
|
continue
|
||||||
|
ell_count += 1
|
||||||
|
if ell_count > 1:
|
||||||
|
raise IndexError("There cannot be more than one ellisis (...) in the index of the tensor, "
|
||||||
|
"but it is currently {}".format(input_slice))
|
||||||
|
for _ in range(len(shape) - len(input_slice) + 1):
|
||||||
|
result.append(Slice(None, None, None))
|
||||||
|
return tuple(result)
|
||||||
|
|
||||||
|
|
||||||
|
@constexpr
|
||||||
|
def slice2indices(input_slices, shape):
|
||||||
|
"""
|
||||||
|
Converts slice to indices.
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
slices (Union[Slice, tuple[Slice]]): Slice tuple or slice.
|
||||||
|
shape (tuple): The shape of a tensor is an integer element tuple.
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
Tensor, the shape is (n, 1).
|
||||||
|
"""
|
||||||
|
begin, end, strides = slice_expand(input_slices, shape)
|
||||||
|
np_r = []
|
||||||
|
for i, element in enumerate(shape):
|
||||||
|
s = begin[i] if (begin[i] >= 0) else (element + begin[i])
|
||||||
|
e = end[i] if (end[i] >= 0) else (element + end[i])
|
||||||
|
np_r.append(np.r_[s:e:strides[i]])
|
||||||
|
# Reference: np.ravel_multi_index((np.ix_(np.r_[1:3:1], np.r_[0:4:1], np.r_[4:0:-1])), a.shape)
|
||||||
|
np_ix = np.ix_(*np_r)
|
||||||
|
ravel = np.ravel_multi_index(np_ix, shape)
|
||||||
|
ravel = Tensor(ravel.reshape(-1, 1), dtype=mstype.int32)
|
||||||
|
return ravel
|
||||||
|
|
||||||
|
|
||||||
|
@constexpr
|
||||||
|
def check_indices(indices_size, index):
|
||||||
|
"""Checks indices whether is empty."""
|
||||||
|
if indices_size < 1:
|
||||||
|
raise IndexError("The tensor's index is unreasonable. index:{}".format(index))
|
||||||
|
return indices_size
|
||||||
|
|
||||||
|
|
||||||
|
@constexpr
|
||||||
|
def check_indices_value_size(indices_size, value_size):
|
||||||
|
"""Checks if the sizes are already matched."""
|
||||||
|
if value_size < 1:
|
||||||
|
raise ValueError("The value assigned to tensor cannot be empty.")
|
||||||
|
if value_size > 1:
|
||||||
|
if value_size != indices_size:
|
||||||
|
raise ValueError(
|
||||||
|
"The value given to tensor does not match the index size,"
|
||||||
|
" value size:{}, indics size:{}".format(value_size, indices_size))
|
||||||
|
return value_size
|
||||||
|
|
||||||
|
|
||||||
|
@constexpr
|
||||||
|
def integer_to_indices(index, shape):
|
||||||
|
"""Converts int or tuple[int] to indices."""
|
||||||
|
size = reduce(lambda x, y: x * y, shape)
|
||||||
|
range_ = np.arange(size).reshape(shape)
|
||||||
|
value = range_[index]
|
||||||
|
value = value.reshape(-1, 1)
|
||||||
|
return Tensor(value, dtype=mstype.int32)
|
||||||
|
|
||||||
|
|
||||||
|
@constexpr
|
||||||
|
def tuple_element_is_slice(indexs):
|
||||||
|
"""Judges tuple element type."""
|
||||||
|
if not indexs:
|
||||||
|
raise IndexError("Tensor's index cannot be empty.")
|
||||||
|
if isinstance(indexs, tuple):
|
||||||
|
for _, ele in enumerate(indexs):
|
||||||
|
if not isinstance(ele, Slice):
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
@constexpr
|
||||||
|
def tuple_element_is_int(indexs):
|
||||||
|
"""Judges tuple element type."""
|
||||||
|
if not indexs:
|
||||||
|
raise IndexError("Tensor's index cannot be empty.")
|
||||||
|
if isinstance(indexs, tuple):
|
||||||
|
for _, ele in enumerate(indexs):
|
||||||
|
if not isinstance(ele, int):
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
@constexpr
|
||||||
|
def tuple_elements_type(types):
|
||||||
|
"""Judges the type of all elements of the tuple."""
|
||||||
|
tensors_number = 0
|
||||||
|
for ele in types:
|
||||||
|
if isinstance(ele, mstype.tensor_type):
|
||||||
|
tensors_number += 1
|
||||||
|
if tensors_number == len(types):
|
||||||
|
return ALL_TENSOR
|
||||||
|
if tensors_number == 0:
|
||||||
|
return NO_TENSOR
|
||||||
|
return CONTAIN_TENSOR
|
||||||
|
|
||||||
|
|
||||||
|
@constexpr
|
||||||
|
def check_value_elements(data_dtype, types):
|
||||||
|
"""Judges the type of all elements of the tuple."""
|
||||||
|
tensors_number = 0
|
||||||
|
scalars_number = 0
|
||||||
|
for i, ele in enumerate(types):
|
||||||
|
if isinstance(ele, mstype.tensor_type):
|
||||||
|
ele_dtype = ele.element_type()
|
||||||
|
if data_dtype == ele_dtype:
|
||||||
|
tensors_number += 1
|
||||||
|
else:
|
||||||
|
raise TypeError(f"For '{TENSOR_SETITEM}', the data type of {i}th tensor '{ele_dtype}' "
|
||||||
|
f"in value tuple is not consistent with origin tensor data type '{data_dtype}'.")
|
||||||
|
elif mstype.issubclass_(ele, data_dtype):
|
||||||
|
scalars_number += 1
|
||||||
|
else:
|
||||||
|
raise TypeError(f"For '{TENSOR_SETITEM}', the {i}th element type '{ele}' in "
|
||||||
|
f"value tuple is not consistent with origin tensor data type '{data_dtype}'.")
|
||||||
|
if tensors_number == len(types):
|
||||||
|
return ALL_TENSOR
|
||||||
|
if scalars_number == len(types):
|
||||||
|
return ALL_SCALAR
|
||||||
|
raise TypeError(f"For '{TENSOR_SETITEM}', the value does not support scalar and tensor mixing, but got {types}.")
|
||||||
|
|
||||||
|
|
||||||
|
@constexpr
|
||||||
|
def get_index_tensor_dtype(dtype):
|
||||||
|
"""Check a tuple of tensor data type."""
|
||||||
|
if dtype == mstype.int32:
|
||||||
|
return INT_
|
||||||
|
if dtype == mstype.bool_:
|
||||||
|
return BOOL_
|
||||||
|
raise TypeError(f"For '{TENSOR_SETITEM}', the index tensor data type '{dtype}' is not supported.")
|
||||||
|
|
||||||
|
|
||||||
|
@constexpr
|
||||||
|
def check_index_tensors_dtype(dtypes, op_name):
|
||||||
|
"""Check a tuple of tensor data type."""
|
||||||
|
if op_name == TENSOR_GETITEM:
|
||||||
|
valid_dtypes = (mstype.int32, mstype.int64)
|
||||||
|
elif op_name == TENSOR_SETITEM:
|
||||||
|
valid_dtypes = (mstype.int32,)
|
||||||
|
else:
|
||||||
|
raise ValueError("Unsupported operation.")
|
||||||
|
for ele in dtypes:
|
||||||
|
if ele in valid_dtypes and ele == dtypes[0]:
|
||||||
|
continue
|
||||||
|
raise TypeError(f"For '{op_name}', the index tensors data type must be same, "
|
||||||
|
f"and should be one of the following: {valid_dtypes}, but got {dtypes}.")
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
@constexpr
|
||||||
|
def check_tensor_dtype_valid(dtype, valid_dtypes):
|
||||||
|
"""Check a tensor data type."""
|
||||||
|
if dtype in valid_dtypes:
|
||||||
|
return True
|
||||||
|
raise TypeError(f"The index tensor data type must be one of "
|
||||||
|
f"the following: {valid_dtypes}, but got {dtype}.")
|
||||||
|
|
||||||
|
|
||||||
|
@constexpr
|
||||||
|
def check_tensors_dtype_same(x_dtype, y_dtype, op_name):
|
||||||
|
"""Check tensors data type same."""
|
||||||
|
if x_dtype == y_dtype:
|
||||||
|
return True
|
||||||
|
raise TypeError(f"For '{op_name}', the value data type '{y_dtype}' "
|
||||||
|
f"is not consistent with origin tensor data type {x_dtype}.")
|
||||||
|
|
||||||
|
|
||||||
|
@constexpr
|
||||||
|
def broadcast_shapes(shapes, op_name):
|
||||||
|
"""Broadcasts a tuple of tensor."""
|
||||||
|
broadcast_shape = shapes[0]
|
||||||
|
for i, shape in enumerate(shapes):
|
||||||
|
logger.debug(f"Broadcasts the {i}th tensor, the shape is {shape}.")
|
||||||
|
broadcast_shape = op_utils.get_broadcast_shape(broadcast_shape, shape, op_name)
|
||||||
|
return tuple(broadcast_shape)
|
||||||
|
|
||||||
|
|
||||||
|
@constexpr
|
||||||
|
def check_two_shapes_need_broadcast(shape_x, shape_y):
|
||||||
|
"""Check two shapes need broadcast."""
|
||||||
|
error = ValueError(f"For 'tensor setitem with tensor', the value tensor shape "
|
||||||
|
f"{shape_y} could not broadcast the required updates shape {shape_x}.")
|
||||||
|
if len(shape_y) > len(shape_x):
|
||||||
|
raise error
|
||||||
|
for i in range(-len(shape_y), 0):
|
||||||
|
if shape_y[i] > shape_x[i]:
|
||||||
|
raise error
|
||||||
|
if shape_y[i] < shape_x[i] and shape_y[i] != 1:
|
||||||
|
raise error
|
||||||
|
if shape_y == shape_x:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
@constexpr
|
||||||
|
def compute_multiples(origin_shape, broadcast_shape):
|
||||||
|
"""Compute multiples between broadcast_shape with origin_shape."""
|
||||||
|
len_gap = len(broadcast_shape) - len(origin_shape)
|
||||||
|
return broadcast_shape[0:len_gap] + tuple(map(lambda x, y: x // y, broadcast_shape[len_gap:], origin_shape))
|
||||||
|
|
||||||
|
|
||||||
|
def tile(broadcast_shape, x):
|
||||||
|
multiples = compute_multiples(F.shape(x), broadcast_shape)
|
||||||
|
return F.tile(x, multiples)
|
||||||
|
|
||||||
|
|
||||||
|
@constexpr
|
||||||
|
def check_shapes_same(value_shapes, op_name):
|
||||||
|
"""Check if the shapes in the tuple are consistent."""
|
||||||
|
for i, shape in enumerate(value_shapes):
|
||||||
|
if shape != value_shapes[0]:
|
||||||
|
raise ValueError(f"For '{op_name}', the {i}th tensor shape in value tuple "
|
||||||
|
f"is not same as the first tensor shape.")
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
@constexpr
|
||||||
|
def convert_scalar_to_tensor(data_shape, data_dtype, indices_shape, value, op_type):
|
||||||
|
"""Convert a scalar to a tensor."""
|
||||||
|
if op_type == SET_ITEM_BY_ONE_TENSOR:
|
||||||
|
updates_shape = indices_shape + data_shape[1:]
|
||||||
|
else:
|
||||||
|
updates_shape = indices_shape[:-1] + data_shape[indices_shape[-1]:]
|
||||||
|
if isinstance(value, mstype.dtype_to_pytype(data_dtype)):
|
||||||
|
return Tensor(np.full(updates_shape, value), dtype=data_dtype)
|
||||||
|
raise TypeError(f"For '{TENSOR_SETITEM}', the value type '{value.__class__.__name__}'"
|
||||||
|
f" is not consistent with tensor data type {data_dtype}.")
|
||||||
|
|
||||||
|
|
||||||
|
@constexpr
|
||||||
|
def convert_tuple_of_scalar_to_tensor(data_shape, data_dtype, index_shape, value, op_type):
|
||||||
|
"""Convert a tuple of scalar to a tensor."""
|
||||||
|
updates_shape = generate_updates_shape(data_shape, index_shape, op_type)
|
||||||
|
if len(value) != updates_shape[-1]:
|
||||||
|
raise ValueError(f"For '{TENSOR_SETITEM}', the number of elements : {len(value)} in the updates tuple "
|
||||||
|
f"does not meet the requirements: {updates_shape[-1]}.")
|
||||||
|
array = np.array(value, dtype=mstype.dtype_to_nptype(data_dtype))
|
||||||
|
reps = compute_multiples(updates_shape[-1:], updates_shape)
|
||||||
|
return Tensor(np.tile(array, reps))
|
||||||
|
|
||||||
|
|
||||||
|
@constexpr
|
||||||
|
def generate_updates_shape(data_shape, index_shape, op_type):
|
||||||
|
"""Generate updates shape for 'tensor setitem'."""
|
||||||
|
if op_type == SET_ITEM_BY_ONE_TENSOR:
|
||||||
|
updates_shape = index_shape + data_shape[1:]
|
||||||
|
else:
|
||||||
|
updates_shape = index_shape[:-1] + data_shape[index_shape[-1]:]
|
||||||
|
return updates_shape
|
||||||
|
|
||||||
|
|
||||||
|
@constexpr
|
||||||
|
def check_number_of_index_tensor(data_shape, tuple_len, op_name):
|
||||||
|
"""Check if the number of index tensor exceeds the dimension of the operated tensor."""
|
||||||
|
if tuple_len <= len(data_shape):
|
||||||
|
return True
|
||||||
|
raise IndexError(f"For '{op_name}', the number {tuple_len} of index tensor "
|
||||||
|
f"is greater than the dimension {len(data_shape)} of the operated tensor.")
|
||||||
|
|
||||||
|
|
||||||
|
def generate_indeices_from_tuple_of_tensor(data, tuple_index, op_name):
|
||||||
|
"""Generate an indices tensor from a tuple of tensor."""
|
||||||
|
indices = None
|
||||||
|
check_index_tensor_number = check_number_of_index_tensor(F.shape(data), len(tuple_index), op_name)
|
||||||
|
if check_index_tensor_number:
|
||||||
|
dtype_tuple = hyper_map(F.dtype, tuple_index)
|
||||||
|
check_dtypes = check_index_tensors_dtype(dtype_tuple, op_name)
|
||||||
|
if check_dtypes:
|
||||||
|
shape_tuple = hyper_map(F.shape, tuple_index)
|
||||||
|
broadcast_shape = broadcast_shapes(shape_tuple, op_name)
|
||||||
|
broadcast_tensors = hyper_map(F.partial(tile, broadcast_shape), tuple_index)
|
||||||
|
indices = pack(broadcast_tensors)
|
||||||
|
return indices
|
||||||
|
|
||||||
|
|
||||||
|
def generate_updates_from_scalar(data, indices, value, op_type):
|
||||||
|
"""Generate an updates tensor from a scalar."""
|
||||||
|
data_shape = F.shape(data)
|
||||||
|
indices_shape = F.shape(indices)
|
||||||
|
data_dtype = F.dtype(data)
|
||||||
|
return convert_scalar_to_tensor(data_shape, data_dtype, indices_shape, value, op_type)
|
||||||
|
|
||||||
|
|
||||||
|
def generate_updates_from_tuple(data, index, value, op_type):
|
||||||
|
"""Generate an updates tensor from a tuple."""
|
||||||
|
value_types = hyper_map(F.typeof, value)
|
||||||
|
data_dtype = F.dtype(data)
|
||||||
|
value_elements_type = check_value_elements(data_dtype, value_types)
|
||||||
|
if value_elements_type == ALL_TENSOR:
|
||||||
|
value_shapes = hyper_map(F.shape, value)
|
||||||
|
shapes_same = check_shapes_same(value_shapes, TENSOR_SETITEM)
|
||||||
|
if shapes_same:
|
||||||
|
value = F.pack(value)
|
||||||
|
return generate_updates_from_tensor(data, index, value, op_type)
|
||||||
|
|
||||||
|
data_shape = F.shape(data)
|
||||||
|
index_shape = F.shape(index)
|
||||||
|
return convert_tuple_of_scalar_to_tensor(data_shape, data_dtype, index_shape, value, op_type)
|
||||||
|
|
||||||
|
|
||||||
|
def generate_updates_from_tensor(data, index, value, op_type):
|
||||||
|
"""Generate an updates tensor from a tensor."""
|
||||||
|
data_shape = F.shape(data)
|
||||||
|
index_shape = F.shape(index)
|
||||||
|
value_shape = F.shape(value)
|
||||||
|
data_dtype = F.dtype(data)
|
||||||
|
value_dtype = F.dtype(value)
|
||||||
|
updates_shape = value_shape
|
||||||
|
check_dtype_same = check_tensors_dtype_same(data_dtype, value_dtype, TENSOR_SETITEM)
|
||||||
|
if check_dtype_same:
|
||||||
|
updates_shape = generate_updates_shape(data_shape, index_shape, op_type)
|
||||||
|
need_broadcast = check_two_shapes_need_broadcast(updates_shape, value_shape)
|
||||||
|
if need_broadcast:
|
||||||
|
return tile(updates_shape, value)
|
||||||
|
return value
|
|
@ -15,9 +15,10 @@
|
||||||
|
|
||||||
"""Implementation for getitem."""
|
"""Implementation for getitem."""
|
||||||
|
|
||||||
from ...composite import base
|
from . import _utils as multi_utils
|
||||||
|
from ..import base
|
||||||
from ... import functional as F
|
from ... import functional as F
|
||||||
|
from ....common import dtype as mstype
|
||||||
|
|
||||||
getitem = base.MultitypeFuncGraph('getitem')
|
getitem = base.MultitypeFuncGraph('getitem')
|
||||||
"""
|
"""
|
||||||
|
@ -214,19 +215,45 @@ def _tensor_getitem_by_slice(data, slice_index):
|
||||||
return _tensor_slice(data, slice_index)
|
return _tensor_slice(data, slice_index)
|
||||||
|
|
||||||
|
|
||||||
|
@getitem.register("Tensor", "Tensor")
|
||||||
|
def _tensor_getitem_by_tensor(data, tensor_index):
|
||||||
|
"""
|
||||||
|
Getting item of tensor by slice.
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
data (Tensor): A tensor.
|
||||||
|
tensor_index (Tensor): An index expressed by tensor.
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
Tensor, element type is same as the element type of data.
|
||||||
|
"""
|
||||||
|
check_dtypes = multi_utils.check_tensor_dtype_valid(F.dtype(tensor_index), (mstype.int32, mstype.int64))
|
||||||
|
result = None
|
||||||
|
if check_dtypes:
|
||||||
|
result = F.gather(data, tensor_index, 0)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
@getitem.register("Tensor", "Tuple")
|
@getitem.register("Tensor", "Tuple")
|
||||||
def _tensor_getitem_by_slice_tuple(data, slice_tuple_index):
|
def _tensor_getitem_by_tuple(data, tuple_index):
|
||||||
"""
|
"""
|
||||||
Getting item of tensor by slice tuple.
|
Getting item of tensor by slice tuple.
|
||||||
|
|
||||||
Inputs:
|
Inputs:
|
||||||
data (Tensor): A tensor.
|
data (Tensor): A tensor.
|
||||||
slice_tuple_index (tuple): Index in tuple.
|
tuple_index (tuple): Index in tuple.
|
||||||
|
|
||||||
Outputs:
|
Outputs:
|
||||||
Tensor, element type is same as the element type of data.
|
Tensor, element type is same as the element type of data.
|
||||||
"""
|
"""
|
||||||
return _tensor_slice(data, slice_tuple_index)
|
index_types = multi_utils.hyper_map(F.typeof, tuple_index)
|
||||||
|
index_elements_type = multi_utils.tuple_elements_type(index_types)
|
||||||
|
result = None
|
||||||
|
if index_elements_type == multi_utils.NO_TENSOR:
|
||||||
|
result = _tensor_slice(data, tuple_index)
|
||||||
|
if index_elements_type == multi_utils.ALL_TENSOR:
|
||||||
|
result = _tensor_getitem_by_tuple_of_tensor(data, tuple_index)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
@getitem.register("Tensor", "Ellipsis")
|
@getitem.register("Tensor", "Ellipsis")
|
||||||
|
@ -242,3 +269,10 @@ def _tensor_getitem_by_ellipsis(data, ellipsis_index):
|
||||||
Tensor, same as data.
|
Tensor, same as data.
|
||||||
"""
|
"""
|
||||||
return _tensor_slice(data, ellipsis_index)
|
return _tensor_slice(data, ellipsis_index)
|
||||||
|
|
||||||
|
|
||||||
|
def _tensor_getitem_by_tuple_of_tensor(data, tuple_index):
|
||||||
|
"""Tensor getitem by a tuple of tensor."""
|
||||||
|
indices = multi_utils.generate_indeices_from_tuple_of_tensor(data, tuple_index, multi_utils.TENSOR_GETITEM)
|
||||||
|
result = F.gather_nd(data, indices)
|
||||||
|
return result
|
||||||
|
|
|
@ -18,10 +18,11 @@
|
||||||
from ...composite import base
|
from ...composite import base
|
||||||
from ....common import dtype as mstype
|
from ....common import dtype as mstype
|
||||||
from ... import functional as F
|
from ... import functional as F
|
||||||
from . import _multitype_ops_util as mult_util
|
from . import _utils as multi_utils
|
||||||
|
|
||||||
setitem = base.MultitypeFuncGraph('setitem')
|
setitem = base.MultitypeFuncGraph('setitem')
|
||||||
|
|
||||||
|
|
||||||
@setitem.register("List", "Number", "String")
|
@setitem.register("List", "Number", "String")
|
||||||
def _list_setitem_with_string(data, number_index, value):
|
def _list_setitem_with_string(data, number_index, value):
|
||||||
"""
|
"""
|
||||||
|
@ -118,7 +119,7 @@ def _dict_setitem_with_number(data, key, value):
|
||||||
|
|
||||||
|
|
||||||
@setitem.register("Tensor", "Tensor", "Tensor")
|
@setitem.register("Tensor", "Tensor", "Tensor")
|
||||||
def _tensor_setitem_by_tensor_v1(data, index, value_tensor):
|
def _tensor_setitem_by_tensor_with_tensor(data, index, value_tensor):
|
||||||
"""
|
"""
|
||||||
Tensor assignment.
|
Tensor assignment.
|
||||||
|
|
||||||
|
@ -137,27 +138,15 @@ def _tensor_setitem_by_tensor_v1(data, index, value_tensor):
|
||||||
Outputs:
|
Outputs:
|
||||||
Tensor, element type and shape is same as data.
|
Tensor, element type and shape is same as data.
|
||||||
"""
|
"""
|
||||||
result = None
|
|
||||||
index_dtype = F.dtype(index)
|
index_dtype = F.dtype(index)
|
||||||
index_shape = F.shape(index)
|
tensor_dtype = multi_utils.get_index_tensor_dtype(index_dtype)
|
||||||
check_result = mult_util.check_tensor_setitem_index(mstype.tensor, index_dtype)
|
if tensor_dtype == multi_utils.INT_:
|
||||||
if check_result:
|
return _tensor_setitem_by_int_tensor_with_tensor(data, index, value_tensor)
|
||||||
data_shape = F.shape(data)
|
return _tensor_setitem_by_bool_tensor_with_tensor(data, index, value_tensor)
|
||||||
data_shape = mult_util.check_equal(data_shape, index_shape,
|
|
||||||
"The tensor(shape={}) and tensor index(shape={}) should be the same shape.")
|
|
||||||
size = F.size(value_tensor)
|
|
||||||
size = mult_util.check_equal(1, size,
|
|
||||||
"When assign value is a tensor, its size should be {}, but current size is {}.")
|
|
||||||
dtype = F.dtype(data)
|
|
||||||
u_cast = F.cast(value_tensor, dtype)
|
|
||||||
one_data = F.ones_like(data)
|
|
||||||
u = F.tensor_mul(one_data, u_cast)
|
|
||||||
result = F.select(index, u, data)
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
@setitem.register("Tensor", "Tensor", "Number")
|
@setitem.register("Tensor", "Tensor", "Number")
|
||||||
def _tensor_setitem_by_tensor_v2(data, index, value):
|
def _tensor_setitem_by_tensor_with_number(data, index, value):
|
||||||
"""
|
"""
|
||||||
Tensor assignment.
|
Tensor assignment.
|
||||||
|
|
||||||
|
@ -171,22 +160,128 @@ def _tensor_setitem_by_tensor_v2(data, index, value):
|
||||||
Inputs:
|
Inputs:
|
||||||
data (Tensor): Assigned tensor.
|
data (Tensor): Assigned tensor.
|
||||||
index (Tensor): Tensor of bool type.
|
index (Tensor): Tensor of bool type.
|
||||||
value_tensor (Number): Assignment value.
|
value (Number): Assignment value.
|
||||||
|
|
||||||
Outputs:
|
Outputs:
|
||||||
Tensor, element type and shape is same as data.
|
Tensor, element type and shape is same as data.
|
||||||
"""
|
"""
|
||||||
result = None
|
|
||||||
index_dtype = F.dtype(index)
|
index_dtype = F.dtype(index)
|
||||||
index_shape = F.shape(index)
|
tensor_dtype = multi_utils.get_index_tensor_dtype(index_dtype)
|
||||||
check_result = mult_util.check_tensor_setitem_index(mstype.tensor, index_dtype)
|
if tensor_dtype == multi_utils.BOOL_:
|
||||||
if check_result:
|
return _tensor_setitem_by_bool_tensor_with_scalar(data, index, value)
|
||||||
shape = F.shape(data)
|
return _tensor_setitem_by_int_tensor_with_scalar(data, index, value)
|
||||||
shape = mult_util.check_equal(
|
|
||||||
shape, index_shape, "The tensor(shape={}) and tensor index(shape={}) should be the same shape.")
|
|
||||||
dtype = F.dtype(data)
|
@setitem.register("Tensor", "Tuple", "Number")
|
||||||
u = F.fill(dtype, shape, value)
|
def _tensor_setitem_by_tuple_with_number(data, tuple_index, value):
|
||||||
result = F.select(index, u, data)
|
"""
|
||||||
|
Tensor assignment.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
Syntax support: A[B, C, D] = u.
|
||||||
|
Restraint condition: 1) A is a Tensor, and B, C, D are index.
|
||||||
|
2) u is a scalar.
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
data (Tensor): Assigned tensor.
|
||||||
|
index (Tuple): An index tuple.
|
||||||
|
value (Number): Assignment value.
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
Tensor, element type and shape is same as data.
|
||||||
|
"""
|
||||||
|
index_types = multi_utils.hyper_map(F.typeof, tuple_index)
|
||||||
|
index_elements_type = multi_utils.tuple_elements_type(index_types)
|
||||||
|
result = None
|
||||||
|
if index_elements_type == multi_utils.NO_TENSOR:
|
||||||
|
result = _tensor_assgin_number(data, tuple_index, value)
|
||||||
|
if index_elements_type == multi_utils.ALL_TENSOR:
|
||||||
|
indices = multi_utils.generate_indeices_from_tuple_of_tensor(data, tuple_index, multi_utils.TENSOR_SETITEM)
|
||||||
|
updates = multi_utils.generate_updates_from_scalar(data, indices, value,
|
||||||
|
multi_utils.SET_ITEM_BY_TUPLE_OF_TENSOR)
|
||||||
|
result = F.scatter_nd_update(data, indices, updates)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
@setitem.register("Tensor", "Tuple", "Tensor")
|
||||||
|
def _tensor_setitem_by_tuple_with_tensor(data, tuple_index, value):
|
||||||
|
"""
|
||||||
|
Tensor assignment.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
Syntax support: A[B, C, D] = U.
|
||||||
|
Restraint condition: 1) A is a Tensor, and B, C, D are index Tensors.
|
||||||
|
2) U is a Tensor.
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
data (Tensor): Assigned tensor.
|
||||||
|
index (Tuple): An index tuple.
|
||||||
|
value (Tensor): Assignment tensor, should has the same data type as 'data'.
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
Tensor, element type and shape is same as data.
|
||||||
|
"""
|
||||||
|
index_types = multi_utils.hyper_map(F.typeof, tuple_index)
|
||||||
|
index_elements_type = multi_utils.tuple_elements_type(index_types)
|
||||||
|
result = None
|
||||||
|
if index_elements_type == multi_utils.NO_TENSOR:
|
||||||
|
result = _tensor_assgin_tensor(data, tuple_index, value)
|
||||||
|
if index_elements_type == multi_utils.ALL_TENSOR:
|
||||||
|
indices = multi_utils.generate_indeices_from_tuple_of_tensor(data, tuple_index, multi_utils.TENSOR_SETITEM)
|
||||||
|
updates = multi_utils.generate_updates_from_tensor(data, indices, value,
|
||||||
|
multi_utils.SET_ITEM_BY_TUPLE_OF_TENSOR)
|
||||||
|
result = F.scatter_nd_update(data, indices, updates)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
@setitem.register("Tensor", "Tuple", "Tuple")
|
||||||
|
def _tensor_setitem_by_tuple_with_tuple(data, tuple_index, value):
|
||||||
|
"""
|
||||||
|
Tensor assignment.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
Syntax support: A[B, C, D] = U.
|
||||||
|
Restraint condition: 1) A is a Tensor, and B, C, D are index Tensors.
|
||||||
|
2) A B and C could be broadcast.
|
||||||
|
3) U is a Tensor.
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
data (Tensor): Assigned tensor.
|
||||||
|
index (Tuple): A tuple of tensor, these tensor could be broadcast.
|
||||||
|
value (Tensor): Assignment tensor, should has the same data type as 'data'.
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
Tensor, element type and shape is same as data.
|
||||||
|
"""
|
||||||
|
index_types = multi_utils.hyper_map(F.typeof, tuple_index)
|
||||||
|
index_elements_type = multi_utils.tuple_elements_type(index_types)
|
||||||
|
result = None
|
||||||
|
if index_elements_type == multi_utils.ALL_TENSOR:
|
||||||
|
indices = multi_utils.generate_indeices_from_tuple_of_tensor(data, tuple_index, multi_utils.TENSOR_SETITEM)
|
||||||
|
updates = multi_utils.generate_updates_from_tuple(data, indices, value,
|
||||||
|
multi_utils.SET_ITEM_BY_TUPLE_OF_TENSOR)
|
||||||
|
result = F.scatter_nd_update(data, indices, updates)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
@setitem.register("Tensor", "Tensor", "Tuple")
|
||||||
|
def _tensor_setitem_by_tensor_v2(data, index, value):
|
||||||
|
"""
|
||||||
|
Tensor assignment.
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
data (Tensor): Assigned tensor.
|
||||||
|
index (Tensor): Tensor of bool type.
|
||||||
|
value (Tuple): Assignment value.
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
Tensor, element type and shape is same as data.
|
||||||
|
"""
|
||||||
|
index_dtype = F.dtype(index)
|
||||||
|
check_dtype = multi_utils.check_tensor_dtype_valid(index_dtype, (mstype.int32, mstype.int64))
|
||||||
|
result = None
|
||||||
|
if check_dtype:
|
||||||
|
result = _tensor_setitem_by_tensor_with_tuple(data, index, value)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
@ -212,66 +307,6 @@ def _tensor_setitem_with_slice_v3(data, input_slice, value):
|
||||||
return _tensor_assgin_tensor(data, input_slice, value)
|
return _tensor_assgin_tensor(data, input_slice, value)
|
||||||
|
|
||||||
|
|
||||||
@setitem.register("Tensor", "Tuple", "Tensor")
|
|
||||||
def _tensor_setitem_with_slice_v4(data, input_slice, value):
|
|
||||||
"""
|
|
||||||
Tensor assignment.
|
|
||||||
|
|
||||||
Note:
|
|
||||||
Syntax support: A[tuple(Slice)] = U, and A[tuple(Number)] = U
|
|
||||||
Restraint condition: A is a Tensor
|
|
||||||
Slice like "1:3, ::, :4:-1"
|
|
||||||
U is a Tensor(size=1) or Tensor(size>1)
|
|
||||||
|
|
||||||
Inputs:
|
|
||||||
data (Tensor): Assigned tensor.
|
|
||||||
input_slice (Union[tuple[Slice], tuple[Number]]): Slice expression.
|
|
||||||
value (Number): Assignment value.
|
|
||||||
|
|
||||||
Outputs:
|
|
||||||
Tensor, element type and shape is same as data.
|
|
||||||
"""
|
|
||||||
return _tensor_assgin_tensor(data, input_slice, value)
|
|
||||||
|
|
||||||
|
|
||||||
def _tensor_assgin_tensor(data, input_slice, value):
|
|
||||||
"""Assigns a tensor value to the tensor by slice."""
|
|
||||||
result = None
|
|
||||||
check_result = mult_util.check_tensor_setitem_index(input_slice)
|
|
||||||
if check_result:
|
|
||||||
data_shape = F.shape(data)
|
|
||||||
indices = mult_util.slice2indices(input_slice, data_shape)
|
|
||||||
is_tuple_int = mult_util.tuple_element_is_int(input_slice)
|
|
||||||
if is_tuple_int:
|
|
||||||
indices = mult_util.integer_to_indices(input_slice, data_shape)
|
|
||||||
result = _tensor_indices_tensor(data, data_shape, input_slice, indices, value)
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
def _tensor_indices_tensor(data, data_shape, index, indices, value):
|
|
||||||
"""Assigns a tensor value to the tensor."""
|
|
||||||
data_size = F.size(data)
|
|
||||||
data_dtype = F.dtype(data)
|
|
||||||
indices_size = F.size(indices)
|
|
||||||
indices_size = mult_util.check_indices(indices_size, index)
|
|
||||||
update = F.fill(mstype.int32, (indices_size,), 1)
|
|
||||||
condition_1d = F.scatter_nd(indices, update, (data_size,))
|
|
||||||
condition = F.reshape(condition_1d, data_shape)
|
|
||||||
condition = F.cast(condition, mstype.bool_)
|
|
||||||
value_fill = None
|
|
||||||
value_size = F.size(value)
|
|
||||||
|
|
||||||
value_size = mult_util.check_indices_value_size(indices_size, value_size)
|
|
||||||
if value_size == 1:
|
|
||||||
value_fill = F.fill(data_dtype, (indices_size,), 1)
|
|
||||||
value = F.cast(value, data_dtype)
|
|
||||||
value_fill = F.tensor_mul(value_fill, value)
|
|
||||||
elif value_size > 1:
|
|
||||||
value_fill = F.reshape(value, (indices_size,))
|
|
||||||
value_1d = F.scatter_nd(indices, value_fill, (data_size,))
|
|
||||||
u = F.reshape(value_1d, data_shape)
|
|
||||||
return F.select(condition, u, data)
|
|
||||||
|
|
||||||
@setitem.register("Tensor", "Slice", "Number")
|
@setitem.register("Tensor", "Slice", "Number")
|
||||||
def _tensor_setitem_with_slice_v1(data, input_slice, value):
|
def _tensor_setitem_with_slice_v1(data, input_slice, value):
|
||||||
"""
|
"""
|
||||||
|
@ -294,63 +329,25 @@ def _tensor_setitem_with_slice_v1(data, input_slice, value):
|
||||||
return _tensor_assgin_number(data, input_slice, value)
|
return _tensor_assgin_number(data, input_slice, value)
|
||||||
|
|
||||||
|
|
||||||
@setitem.register("Tensor", "Tuple", "Number")
|
|
||||||
def _tensor_setitem_with_slice_v2(data, input_slice, value):
|
|
||||||
"""
|
|
||||||
Tensor assignment.
|
|
||||||
|
|
||||||
Note:
|
|
||||||
Syntax support: A[tuple(Slice)] = u, and A[tuple(Number)] = u
|
|
||||||
Restraint condition: A is a Tensor.
|
|
||||||
Slice like "1:3, ::, :4:-1"
|
|
||||||
u is a scalar
|
|
||||||
|
|
||||||
Inputs:
|
|
||||||
data (Tensor): Assigned tensor.
|
|
||||||
input_slice (Union[tuple[Slice], tuple[Number]]): slice expression.
|
|
||||||
value (Number): Assignment value.
|
|
||||||
|
|
||||||
Outputs:
|
|
||||||
Tensor, element type and shape is same as data.
|
|
||||||
"""
|
|
||||||
return _tensor_assgin_number(data, input_slice, value)
|
|
||||||
|
|
||||||
|
|
||||||
def _tensor_assgin_number(data, input_slice, value):
|
def _tensor_assgin_number(data, input_slice, value):
|
||||||
"""Givens a scalar assign to tensor by slice"""
|
"""Givens a scalar assign to tensor by slice"""
|
||||||
check_result = mult_util.check_tensor_setitem_index(input_slice)
|
check_result = multi_utils.check_tensor_setitem_index(input_slice)
|
||||||
result = None
|
result = None
|
||||||
if check_result:
|
if check_result:
|
||||||
data_shape = F.shape(data)
|
data_shape = F.shape(data)
|
||||||
indices = mult_util.slice2indices(input_slice, data_shape)
|
indices = multi_utils.slice2indices(input_slice, data_shape)
|
||||||
is_tuple_int = mult_util.tuple_element_is_int(input_slice)
|
is_tuple_int = multi_utils.tuple_element_is_int(input_slice)
|
||||||
if is_tuple_int:
|
if is_tuple_int:
|
||||||
indices = mult_util.integer_to_indices(input_slice, data_shape)
|
indices = multi_utils.integer_to_indices(input_slice, data_shape)
|
||||||
result = _tensor_indices_number(data, data_shape, input_slice, indices, value)
|
result = _tensor_indices_number(data, data_shape, input_slice, indices, value)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
def _tensor_indices_number(data, data_shape, index, indices, value):
|
|
||||||
"""Assigns a scalar value to the tensor."""
|
|
||||||
data_size = F.size(data)
|
|
||||||
data_dtype = F.dtype(data)
|
|
||||||
indices_size = F.size(indices)
|
|
||||||
indices_size = mult_util.check_indices(indices_size, index)
|
|
||||||
update = F.fill(mstype.int32, (indices_size,), 1)
|
|
||||||
condition_1d = F.scatter_nd(indices, update, (data_size,))
|
|
||||||
condition = F.reshape(condition_1d, data_shape)
|
|
||||||
condition = F.cast(condition, mstype.bool_)
|
|
||||||
value_fill = F.fill(data_dtype, (indices_size,), value)
|
|
||||||
value_1d = F.scatter_nd(indices, value_fill, (data_size,))
|
|
||||||
u = F.reshape(value_1d, data_shape)
|
|
||||||
return F.select(condition, u, data)
|
|
||||||
|
|
||||||
|
|
||||||
@setitem.register("Tensor", "Number", "Number")
|
@setitem.register("Tensor", "Number", "Number")
|
||||||
def _tensor_setitem_with_int_v1(data, index, value):
|
def _tensor_setitem_with_int_v1(data, index, value):
|
||||||
"""Syntax: A[1] = 3"""
|
"""Syntax: A[1] = 3"""
|
||||||
data_shape = F.shape(data)
|
data_shape = F.shape(data)
|
||||||
indices = mult_util.integer_to_indices(index, data_shape)
|
indices = multi_utils.integer_to_indices(index, data_shape)
|
||||||
return _tensor_indices_number(data, data_shape, index, indices, value)
|
return _tensor_indices_number(data, data_shape, index, indices, value)
|
||||||
|
|
||||||
|
|
||||||
|
@ -358,7 +355,7 @@ def _tensor_setitem_with_int_v1(data, index, value):
|
||||||
def _tensor_setitem_with_int_v2(data, index, value):
|
def _tensor_setitem_with_int_v2(data, index, value):
|
||||||
"""Syntax: A[1] = Tensor"""
|
"""Syntax: A[1] = Tensor"""
|
||||||
data_shape = F.shape(data)
|
data_shape = F.shape(data)
|
||||||
indices = mult_util.integer_to_indices(index, data_shape)
|
indices = multi_utils.integer_to_indices(index, data_shape)
|
||||||
return _tensor_indices_tensor(data, data_shape, index, indices, value)
|
return _tensor_indices_tensor(data, data_shape, index, indices, value)
|
||||||
|
|
||||||
|
|
||||||
|
@ -379,7 +376,7 @@ def _tensor_setitem_with_ellipsis_v2(data, index, value):
|
||||||
data_size = F.size(data)
|
data_size = F.size(data)
|
||||||
value_shape = F.shape(value)
|
value_shape = F.shape(value)
|
||||||
value_size = F.size(value)
|
value_size = F.size(value)
|
||||||
check_result = mult_util.check_ellipsis_shape_size(data_shape, value_shape, data_size, value_size)
|
check_result = multi_utils.check_ellipsis_shape_size(data_shape, value_shape, data_size, value_size)
|
||||||
if check_result:
|
if check_result:
|
||||||
if data_size == value_size:
|
if data_size == value_size:
|
||||||
result = F.reshape(value, data_shape)
|
result = F.reshape(value, data_shape)
|
||||||
|
@ -389,3 +386,108 @@ def _tensor_setitem_with_ellipsis_v2(data, index, value):
|
||||||
param2 = F.cast(value, data_dtype)
|
param2 = F.cast(value, data_dtype)
|
||||||
result = F.tensor_mul(param1, param2)
|
result = F.tensor_mul(param1, param2)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def _tensor_assgin_tensor(data, input_slice, value):
|
||||||
|
"""Assigns a tensor value to the tensor by slice."""
|
||||||
|
result = None
|
||||||
|
check_result = multi_utils.check_tensor_setitem_index(input_slice)
|
||||||
|
if check_result:
|
||||||
|
data_shape = F.shape(data)
|
||||||
|
indices = multi_utils.slice2indices(input_slice, data_shape)
|
||||||
|
is_tuple_int = multi_utils.tuple_element_is_int(input_slice)
|
||||||
|
if is_tuple_int:
|
||||||
|
indices = multi_utils.integer_to_indices(input_slice, data_shape)
|
||||||
|
result = _tensor_indices_tensor(data, data_shape, input_slice, indices, value)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def _tensor_indices_tensor(data, data_shape, index, indices, value):
|
||||||
|
"""Assigns a tensor value to the tensor."""
|
||||||
|
data_size = F.size(data)
|
||||||
|
data_dtype = F.dtype(data)
|
||||||
|
indices_size = F.size(indices)
|
||||||
|
indices_size = multi_utils.check_indices(indices_size, index)
|
||||||
|
update = F.fill(mstype.int32, (indices_size,), 1)
|
||||||
|
condition_1d = F.scatter_nd(indices, update, (data_size,))
|
||||||
|
condition = F.reshape(condition_1d, data_shape)
|
||||||
|
condition = F.cast(condition, mstype.bool_)
|
||||||
|
value_fill = None
|
||||||
|
value_size = F.size(value)
|
||||||
|
|
||||||
|
value_size = multi_utils.check_indices_value_size(indices_size, value_size)
|
||||||
|
if value_size == 1:
|
||||||
|
value_fill = F.fill(data_dtype, (indices_size,), 1)
|
||||||
|
value = F.cast(value, data_dtype)
|
||||||
|
value_fill = F.tensor_mul(value_fill, value)
|
||||||
|
elif value_size > 1:
|
||||||
|
value_fill = F.reshape(value, (indices_size,))
|
||||||
|
value_1d = F.scatter_nd(indices, value_fill, (data_size,))
|
||||||
|
u = F.reshape(value_1d, data_shape)
|
||||||
|
return F.select(condition, u, data)
|
||||||
|
|
||||||
|
|
||||||
|
def _tensor_indices_number(data, data_shape, index, indices, value):
|
||||||
|
"""Assigns a scalar value to the tensor."""
|
||||||
|
data_size = F.size(data)
|
||||||
|
data_dtype = F.dtype(data)
|
||||||
|
indices_size = F.size(indices)
|
||||||
|
indices_size = multi_utils.check_indices(indices_size, index)
|
||||||
|
update = F.fill(mstype.int32, (indices_size,), 1)
|
||||||
|
condition_1d = F.scatter_nd(indices, update, (data_size,))
|
||||||
|
condition = F.reshape(condition_1d, data_shape)
|
||||||
|
condition = F.cast(condition, mstype.bool_)
|
||||||
|
value_fill = F.fill(data_dtype, (indices_size,), value)
|
||||||
|
value_1d = F.scatter_nd(indices, value_fill, (data_size,))
|
||||||
|
u = F.reshape(value_1d, data_shape)
|
||||||
|
return F.select(condition, u, data)
|
||||||
|
|
||||||
|
|
||||||
|
def _tensor_setitem_by_tensor_with_tuple(data, index, value):
|
||||||
|
"""Set a tensor item by a tensor with a tuple."""
|
||||||
|
updates = multi_utils.generate_updates_from_tuple(data, index, value,
|
||||||
|
multi_utils.SET_ITEM_BY_ONE_TENSOR)
|
||||||
|
result = F.scatter_update(data, index, updates)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def _tensor_setitem_by_int_tensor_with_scalar(data, index, value):
|
||||||
|
"""Set a tensor item by a int tensor with a scalar."""
|
||||||
|
updates = multi_utils.generate_updates_from_scalar(data, index, value,
|
||||||
|
multi_utils.SET_ITEM_BY_ONE_TENSOR)
|
||||||
|
return F.scatter_update(data, index, updates)
|
||||||
|
|
||||||
|
|
||||||
|
def _tensor_setitem_by_bool_tensor_with_scalar(data, index, value):
|
||||||
|
"""Set a tensor item by a bool tensor with a scalar."""
|
||||||
|
index_shape = F.shape(index)
|
||||||
|
shape = F.shape(data)
|
||||||
|
shape = multi_utils.check_equal(
|
||||||
|
shape, index_shape, "The tensor(shape={}) and tensor index(shape={}) should be the same shape.")
|
||||||
|
dtype = F.dtype(data)
|
||||||
|
u = F.fill(dtype, shape, value)
|
||||||
|
return F.select(index, u, data)
|
||||||
|
|
||||||
|
|
||||||
|
def _tensor_setitem_by_int_tensor_with_tensor(data, index, value):
|
||||||
|
"""Set a tensor item by a int tensor with a tensor."""
|
||||||
|
updates = multi_utils.generate_updates_from_tensor(data, index, value,
|
||||||
|
multi_utils.SET_ITEM_BY_ONE_TENSOR)
|
||||||
|
return F.scatter_update(data, index, updates)
|
||||||
|
|
||||||
|
|
||||||
|
def _tensor_setitem_by_bool_tensor_with_tensor(data, index, value):
|
||||||
|
"""Set a tensor item by a bool tensor with a tensor."""
|
||||||
|
index_shape = F.shape(index)
|
||||||
|
data_shape = F.shape(data)
|
||||||
|
data_shape = multi_utils.check_equal(data_shape, index_shape,
|
||||||
|
"The tensor(shape={}) and tensor index(shape={}) should be the same shape.")
|
||||||
|
size = F.size(value)
|
||||||
|
size = multi_utils.check_equal(1, size,
|
||||||
|
"When assign value is a tensor, its size should be {}, but current size is {}.")
|
||||||
|
dtype = F.dtype(data)
|
||||||
|
u_cast = F.cast(value, dtype)
|
||||||
|
one_data = F.ones_like(data)
|
||||||
|
u = F.tensor_mul(one_data, u_cast)
|
||||||
|
result = F.select(index, u, data)
|
||||||
|
return result
|
||||||
|
|
|
@ -31,6 +31,7 @@ dtype = P.DType()
|
||||||
issubclass_ = P.IsSubClass()
|
issubclass_ = P.IsSubClass()
|
||||||
isinstance_ = P.IsInstance()
|
isinstance_ = P.IsInstance()
|
||||||
fill = P.Fill()
|
fill = P.Fill()
|
||||||
|
tile = P.Tile()
|
||||||
select = P.Select()
|
select = P.Select()
|
||||||
size = P.Size()
|
size = P.Size()
|
||||||
ones_like = P.OnesLike()
|
ones_like = P.OnesLike()
|
||||||
|
@ -70,6 +71,12 @@ scalar_cast = P.ScalarCast()
|
||||||
print_ = P.Print()
|
print_ = P.Print()
|
||||||
expand_dims = P.ExpandDims()
|
expand_dims = P.ExpandDims()
|
||||||
scatter_nd = P.ScatterNd()
|
scatter_nd = P.ScatterNd()
|
||||||
|
gather = P.GatherV2()
|
||||||
|
gather_nd = P.GatherNd()
|
||||||
|
scatter_update = P.ScatterUpdate()
|
||||||
|
scatter_nd_update = P.ScatterNdUpdate()
|
||||||
|
pack = P.Pack()
|
||||||
|
|
||||||
|
|
||||||
tuple_setitem = Primitive('tuple_setitem')
|
tuple_setitem = Primitive('tuple_setitem')
|
||||||
tuple_getitem = Primitive('tuple_getitem')
|
tuple_getitem = Primitive('tuple_getitem')
|
||||||
|
|
|
@ -24,7 +24,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
|
||||||
Fill, GatherNd, GatherV2, InvertPermutation,
|
Fill, GatherNd, GatherV2, InvertPermutation,
|
||||||
IsInstance, IsSubClass, ArgMaxWithValue, OnesLike, ZerosLike,
|
IsInstance, IsSubClass, ArgMaxWithValue, OnesLike, ZerosLike,
|
||||||
Rank, Reshape, ResizeNearestNeighbor, ArgMinWithValue,
|
Rank, Reshape, ResizeNearestNeighbor, ArgMinWithValue,
|
||||||
SameTypeShape, ScatterMax,
|
SameTypeShape, ScatterMax, ScatterUpdate,
|
||||||
ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select,
|
ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select,
|
||||||
Shape, Size, Slice, Split,
|
Shape, Size, Slice, Split,
|
||||||
Squeeze, StridedSlice, Tile,
|
Squeeze, StridedSlice, Tile,
|
||||||
|
@ -193,6 +193,7 @@ __all__ = [
|
||||||
'Pad',
|
'Pad',
|
||||||
'MirrorPad',
|
'MirrorPad',
|
||||||
'GatherNd',
|
'GatherNd',
|
||||||
|
'ScatterUpdate',
|
||||||
'ScatterNdUpdate',
|
'ScatterNdUpdate',
|
||||||
'Floor',
|
'Floor',
|
||||||
'NMSWithMask',
|
'NMSWithMask',
|
||||||
|
|
|
@ -19,7 +19,7 @@ from ..._c_expression import signature_rw as sig_rw
|
||||||
from ..._c_expression import signature_kind as sig_kind
|
from ..._c_expression import signature_kind as sig_kind
|
||||||
from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register
|
from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register
|
||||||
from ..._checkparam import Validator as validator, Rel
|
from ..._checkparam import Validator as validator, Rel
|
||||||
from .._utils import _get_concat_offset
|
from .._utils import get_concat_offset
|
||||||
from ...common import dtype as mstype
|
from ...common import dtype as mstype
|
||||||
|
|
||||||
|
|
||||||
|
@ -136,7 +136,7 @@ class ConcatOffset(PrimitiveWithInfer):
|
||||||
axis = self.axis
|
axis = self.axis
|
||||||
x_shp = input_x['shape']
|
x_shp = input_x['shape']
|
||||||
x_type = input_x['dtype']
|
x_type = input_x['dtype']
|
||||||
offset, _, axis = _get_concat_offset(x_shp, x_type, axis, self.name)
|
offset, _, axis = get_concat_offset(x_shp, x_type, axis, self.name)
|
||||||
self.add_prim_attr('T', x_type[0].element_type())
|
self.add_prim_attr('T', x_type[0].element_type())
|
||||||
offset_values = []
|
offset_values = []
|
||||||
for i in range(len(x_shp)):
|
for i in range(len(x_shp)):
|
||||||
|
|
|
@ -24,16 +24,15 @@ import itertools
|
||||||
import numbers
|
import numbers
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from ..._c_expression import signature_rw as sig_rw
|
|
||||||
from ..._c_expression import signature_kind as sig_kind
|
|
||||||
from ..._checkparam import Validator as validator
|
from ..._checkparam import Validator as validator
|
||||||
from ..._checkparam import Rel
|
from ..._checkparam import Rel
|
||||||
from ...common import dtype as mstype
|
from ...common import dtype as mstype
|
||||||
from ...common.tensor import Tensor
|
from ...common.tensor import Tensor
|
||||||
from ..operations.math_ops import _infer_shape_reduce
|
from ..operations.math_ops import _infer_shape_reduce
|
||||||
from .._utils import _get_concat_offset
|
from .._utils import get_concat_offset
|
||||||
from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register
|
from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register
|
||||||
|
|
||||||
|
|
||||||
def _check_infer_attr_reduce(axis, keep_dims, prim_name):
|
def _check_infer_attr_reduce(axis, keep_dims, prim_name):
|
||||||
validator.check_value_type('keep_dims', keep_dims, [bool], prim_name)
|
validator.check_value_type('keep_dims', keep_dims, [bool], prim_name)
|
||||||
validator.check_value_type('axis', axis, [int, tuple], prim_name)
|
validator.check_value_type('axis', axis, [int, tuple], prim_name)
|
||||||
|
@ -931,7 +930,7 @@ class InvertPermutation(PrimitiveWithInfer):
|
||||||
z = [x_value[i] for i in range(len(x_value))]
|
z = [x_value[i] for i in range(len(x_value))]
|
||||||
z.sort()
|
z.sort()
|
||||||
|
|
||||||
y = [None]*len(x_value)
|
y = [None] * len(x_value)
|
||||||
for i, value in enumerate(x_value):
|
for i, value in enumerate(x_value):
|
||||||
validator.check_value_type("input[%d]" % i, value, [int], self.name)
|
validator.check_value_type("input[%d]" % i, value, [int], self.name)
|
||||||
validator.check(f'value', z[i], f'index', i, Rel.EQ, self.name)
|
validator.check(f'value', z[i], f'index', i, Rel.EQ, self.name)
|
||||||
|
@ -1111,6 +1110,7 @@ class ArgMinWithValue(PrimitiveWithInfer):
|
||||||
>>> input_x = Tensor(np.random.rand(5))
|
>>> input_x = Tensor(np.random.rand(5))
|
||||||
>>> index, output = P.ArgMinWithValue()(input_x)
|
>>> index, output = P.ArgMinWithValue()(input_x)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@prim_attr_register
|
@prim_attr_register
|
||||||
def __init__(self, axis=0, keep_dims=False):
|
def __init__(self, axis=0, keep_dims=False):
|
||||||
"""init ArgMinWithValue"""
|
"""init ArgMinWithValue"""
|
||||||
|
@ -1352,7 +1352,7 @@ class Concat(PrimitiveWithInfer):
|
||||||
axis = self.axis
|
axis = self.axis
|
||||||
x_shp = input_x['shape']
|
x_shp = input_x['shape']
|
||||||
x_type = input_x['dtype']
|
x_type = input_x['dtype']
|
||||||
_, all_shp, _ = _get_concat_offset(x_shp, x_type, axis, self.name)
|
_, all_shp, _ = get_concat_offset(x_shp, x_type, axis, self.name)
|
||||||
self.add_prim_attr('T', x_type[0].element_type())
|
self.add_prim_attr('T', x_type[0].element_type())
|
||||||
self.add_prim_attr('inputNums', len(x_shp))
|
self.add_prim_attr('inputNums', len(x_shp))
|
||||||
ret_shp = x_shp[0].copy()
|
ret_shp = x_shp[0].copy()
|
||||||
|
@ -1376,15 +1376,13 @@ def _get_pack_shape(x_shape, x_type, axis, prim_name):
|
||||||
if axis < 0:
|
if axis < 0:
|
||||||
axis = axis + rank_base + 1
|
axis = axis + rank_base + 1
|
||||||
for i in range(1, N):
|
for i in range(1, N):
|
||||||
v = x_shape[i]
|
|
||||||
validator.check('len of x_shape[%d]' % i, len(v), 'len of rank_base', rank_base, Rel.EQ, prim_name)
|
|
||||||
validator.check('x_type[%d]' % i, x_type[i], 'base', x_type[0], Rel.EQ, prim_name, TypeError)
|
validator.check('x_type[%d]' % i, x_type[i], 'base', x_type[0], Rel.EQ, prim_name, TypeError)
|
||||||
for j in range(rank_base):
|
if x_shape[i] != x_shape[0]:
|
||||||
if v[j] != x_shape[0][j]:
|
raise ValueError(f"For \'{prim_name}\' element {i} shape in input can not pack with first element")
|
||||||
raise ValueError(f"For \'{prim_name}\' element {i} shape in input can not pack with first element")
|
|
||||||
out_shape.insert(axis, N)
|
out_shape.insert(axis, N)
|
||||||
return out_shape
|
return out_shape
|
||||||
|
|
||||||
|
|
||||||
class Pack(PrimitiveWithInfer):
|
class Pack(PrimitiveWithInfer):
|
||||||
r"""
|
r"""
|
||||||
Packs a list of tensors in specified axis.
|
Packs a list of tensors in specified axis.
|
||||||
|
@ -1831,7 +1829,7 @@ class DiagPart(PrimitiveWithInfer):
|
||||||
return x_type
|
return x_type
|
||||||
|
|
||||||
def infer_shape(self, x_shape):
|
def infer_shape(self, x_shape):
|
||||||
if len(x_shape)%2 != 0 or \
|
if len(x_shape) % 2 != 0 or \
|
||||||
not x_shape:
|
not x_shape:
|
||||||
raise ValueError(f"For \'{self.name}\' input rank must be non-zero and even, but got rank {len(x_shape)}, "
|
raise ValueError(f"For \'{self.name}\' input rank must be non-zero and even, but got rank {len(x_shape)}, "
|
||||||
f"with shapes {x_shape}")
|
f"with shapes {x_shape}")
|
||||||
|
@ -2004,6 +2002,49 @@ class GatherNd(PrimitiveWithInfer):
|
||||||
return x_dtype
|
return x_dtype
|
||||||
|
|
||||||
|
|
||||||
|
class ScatterUpdate(PrimitiveWithInfer):
|
||||||
|
"""
|
||||||
|
Update tensor value by using input indices and value.
|
||||||
|
|
||||||
|
Using given values to update tensor value, along with the input indices.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
use_locking (bool): Whether protect the assignment by a lock. Default: True.
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
- **input_x** (Parameter) - The target tensor, with data type of Parameter.
|
||||||
|
- **indices** (Tensor) - The index of input tensor.
|
||||||
|
- **update** (Tensor) - The tensor to update the input tensor, has the same type as input,
|
||||||
|
and update.shape = indices.shape + input_x.shape[1:].
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
Tensor, has the same shape and type as `input_x`.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> input_x = mindspore.Parameter(Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mindspore.float32))
|
||||||
|
>>> indices = Tensor(np.array([[0, 0], [1, 1]]), mindspore.int32)
|
||||||
|
>>> update = Tensor(np.array([1.0, 2.2]), mindspore.float32)
|
||||||
|
>>> op = P.ScatterNdUpdate()
|
||||||
|
>>> output = op(input_x, indices, update)
|
||||||
|
"""
|
||||||
|
|
||||||
|
@prim_attr_register
|
||||||
|
def __init__(self, use_locking=True):
|
||||||
|
"""Init ScatterNdUpdate"""
|
||||||
|
self.init_prim_io_names(inputs=['x', 'indices', 'value'], outputs=['y'])
|
||||||
|
|
||||||
|
def infer_shape(self, x_shape, indices_shape, value_shape):
|
||||||
|
if indices_shape + x_shape[1:] != value_shape:
|
||||||
|
raise ValueError('Input value are not match with input indices.')
|
||||||
|
return x_shape
|
||||||
|
|
||||||
|
def infer_dtype(self, x_dtype, indices_dtype, value_dtype):
|
||||||
|
validator.check_tensor_type_same({'indices': indices_dtype}, mstype.int_type, self.name)
|
||||||
|
args = {"x": x_dtype, "value": value_dtype}
|
||||||
|
validator.check_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name)
|
||||||
|
return x_dtype
|
||||||
|
|
||||||
|
|
||||||
class ScatterNdUpdate(PrimitiveWithInfer):
|
class ScatterNdUpdate(PrimitiveWithInfer):
|
||||||
"""
|
"""
|
||||||
Update tensor value by using input indices and value.
|
Update tensor value by using input indices and value.
|
||||||
|
@ -2028,11 +2069,6 @@ class ScatterNdUpdate(PrimitiveWithInfer):
|
||||||
>>> op = P.ScatterNdUpdate()
|
>>> op = P.ScatterNdUpdate()
|
||||||
>>> output = op(input_x, indices, update)
|
>>> output = op(input_x, indices, update)
|
||||||
"""
|
"""
|
||||||
__mindspore_signature__ = (
|
|
||||||
('input_x', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD),
|
|
||||||
('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD),
|
|
||||||
('value', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD)
|
|
||||||
)
|
|
||||||
|
|
||||||
@prim_attr_register
|
@prim_attr_register
|
||||||
def __init__(self, use_locking=True):
|
def __init__(self, use_locking=True):
|
||||||
|
@ -2142,10 +2178,10 @@ class SpaceToDepth(PrimitiveWithInfer):
|
||||||
validator.check('x dimension', len(x_shape), '', 4, Rel.EQ)
|
validator.check('x dimension', len(x_shape), '', 4, Rel.EQ)
|
||||||
out_shape = copy.deepcopy(x_shape)
|
out_shape = copy.deepcopy(x_shape)
|
||||||
for i in range(2):
|
for i in range(2):
|
||||||
if out_shape[i+2] % self.block_size != 0:
|
if out_shape[i + 2] % self.block_size != 0:
|
||||||
raise ValueError(f'For \'{self.name}\' input shape[{i+2}] {out_shape[i+2]} should be '
|
raise ValueError(f'For \'{self.name}\' input shape[{i + 2}] {out_shape[i + 2]} should be '
|
||||||
f'fully divided by block_size {self.block_size}')
|
f'fully divided by block_size {self.block_size}')
|
||||||
out_shape[i+2] //= self.block_size
|
out_shape[i + 2] //= self.block_size
|
||||||
|
|
||||||
out_shape[1] *= self.block_size * self.block_size
|
out_shape[1] *= self.block_size * self.block_size
|
||||||
return out_shape
|
return out_shape
|
||||||
|
@ -2199,9 +2235,10 @@ class DepthToSpace(PrimitiveWithInfer):
|
||||||
validator.check('x dimension', len(x_shape), '', 4, Rel.EQ)
|
validator.check('x dimension', len(x_shape), '', 4, Rel.EQ)
|
||||||
out_shape = copy.deepcopy(x_shape)
|
out_shape = copy.deepcopy(x_shape)
|
||||||
for i in range(2):
|
for i in range(2):
|
||||||
out_shape[i+2] *= self.block_size
|
out_shape[i + 2] *= self.block_size
|
||||||
|
|
||||||
validator.check_integer('x_shape[1] % (block_size*block_size)', x_shape[1] % (self.block_size*self.block_size),
|
validator.check_integer('x_shape[1] % (block_size*block_size)',
|
||||||
|
x_shape[1] % (self.block_size * self.block_size),
|
||||||
0, Rel.EQ, self.name)
|
0, Rel.EQ, self.name)
|
||||||
out_shape[1] //= self.block_size * self.block_size
|
out_shape[1] //= self.block_size * self.block_size
|
||||||
return out_shape
|
return out_shape
|
||||||
|
@ -2251,6 +2288,7 @@ class SpaceToBatch(PrimitiveWithInfer):
|
||||||
[[[[1.]]], [[[2.]]], [[[3.]]], [[[4.]]]]
|
[[[[1.]]], [[[2.]]], [[[3.]]], [[[4.]]]]
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@prim_attr_register
|
@prim_attr_register
|
||||||
def __init__(self, block_size, paddings):
|
def __init__(self, block_size, paddings):
|
||||||
"""Init SpaceToBatch"""
|
"""Init SpaceToBatch"""
|
||||||
|
@ -2271,12 +2309,12 @@ class SpaceToBatch(PrimitiveWithInfer):
|
||||||
validator.check_integer('rank of input_x', len(x_shape), 4, Rel.EQ, self.name)
|
validator.check_integer('rank of input_x', len(x_shape), 4, Rel.EQ, self.name)
|
||||||
out_shape = copy.deepcopy(x_shape)
|
out_shape = copy.deepcopy(x_shape)
|
||||||
for i in range(2):
|
for i in range(2):
|
||||||
padded = out_shape[i+2] + self.paddings[i][0] + \
|
padded = out_shape[i + 2] + self.paddings[i][0] + \
|
||||||
self.paddings[i][1]
|
self.paddings[i][1]
|
||||||
if padded % self.block_size != 0:
|
if padded % self.block_size != 0:
|
||||||
raise ValueError(f'For \'{self.name}\' padded[{i}] {padded} should be divisible by '
|
raise ValueError(f'For \'{self.name}\' padded[{i}] {padded} should be divisible by '
|
||||||
f'block_size {self.block_size}')
|
f'block_size {self.block_size}')
|
||||||
out_shape[i+2] = padded // self.block_size
|
out_shape[i + 2] = padded // self.block_size
|
||||||
out_shape[0] *= self.block_size * self.block_size
|
out_shape[0] *= self.block_size * self.block_size
|
||||||
return out_shape
|
return out_shape
|
||||||
|
|
||||||
|
@ -2319,6 +2357,7 @@ class BatchToSpace(PrimitiveWithInfer):
|
||||||
[[[[1., 2.], [3., 4.]]]]
|
[[[[1., 2.], [3., 4.]]]]
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@prim_attr_register
|
@prim_attr_register
|
||||||
def __init__(self, block_size, crops):
|
def __init__(self, block_size, crops):
|
||||||
"""Init BatchToSpace"""
|
"""Init BatchToSpace"""
|
||||||
|
@ -2339,10 +2378,10 @@ class BatchToSpace(PrimitiveWithInfer):
|
||||||
validator.check('rank of input_x', len(x_shape), '', 4)
|
validator.check('rank of input_x', len(x_shape), '', 4)
|
||||||
out_shape = copy.deepcopy(x_shape)
|
out_shape = copy.deepcopy(x_shape)
|
||||||
for i in range(2):
|
for i in range(2):
|
||||||
x_block_prod = out_shape[i+2] * self.block_size
|
x_block_prod = out_shape[i + 2] * self.block_size
|
||||||
crops_sum = self.crops[i][0] + self.crops[i][1]
|
crops_sum = self.crops[i][0] + self.crops[i][1]
|
||||||
validator.check("x block shape prod", x_block_prod, 'crops sum', crops_sum, Rel.GT, self.name)
|
validator.check("x block shape prod", x_block_prod, 'crops sum', crops_sum, Rel.GT, self.name)
|
||||||
out_shape[i+2] = x_block_prod - crops_sum
|
out_shape[i + 2] = x_block_prod - crops_sum
|
||||||
block_size_prod = self.block_size * self.block_size
|
block_size_prod = self.block_size * self.block_size
|
||||||
if out_shape[0] % block_size_prod != 0:
|
if out_shape[0] % block_size_prod != 0:
|
||||||
raise ValueError(f'For \'{self.name}\' input_x dimension 0 {out_shape[0]} should be divisible by '
|
raise ValueError(f'For \'{self.name}\' input_x dimension 0 {out_shape[0]} should be divisible by '
|
||||||
|
|
|
@ -24,7 +24,7 @@ from ..._checkparam import Validator as validator
|
||||||
from ..._checkparam import Rel
|
from ..._checkparam import Rel
|
||||||
from ...common import dtype as mstype
|
from ...common import dtype as mstype
|
||||||
from ...common.tensor import Tensor
|
from ...common.tensor import Tensor
|
||||||
from .._utils import _get_broadcast_shape
|
from .._utils import get_broadcast_shape
|
||||||
from ..primitive import PrimitiveWithInfer, prim_attr_register, _run_op
|
from ..primitive import PrimitiveWithInfer, prim_attr_register, _run_op
|
||||||
|
|
||||||
|
|
||||||
|
@ -75,7 +75,7 @@ class _BinaryOp(PrimitiveWithInfer):
|
||||||
self.init_prim_io_names(inputs=['x', 'y'], outputs=['output'])
|
self.init_prim_io_names(inputs=['x', 'y'], outputs=['output'])
|
||||||
|
|
||||||
def infer_shape(self, x_shape, y_shape):
|
def infer_shape(self, x_shape, y_shape):
|
||||||
return _get_broadcast_shape(x_shape, y_shape, self.name)
|
return get_broadcast_shape(x_shape, y_shape, self.name)
|
||||||
|
|
||||||
|
|
||||||
class _MathBinaryOp(_BinaryOp):
|
class _MathBinaryOp(_BinaryOp):
|
||||||
|
|
|
@ -27,9 +27,15 @@ class IdentityEC(IExectorComponent):
|
||||||
def __call__(self):
|
def __call__(self):
|
||||||
result_id = self.function[keyword.id] + '-' + self.inputs[keyword.id]
|
result_id = self.function[keyword.id] + '-' + self.inputs[keyword.id]
|
||||||
group = self.function[keyword.group] + '-' + self.inputs[keyword.group]
|
group = self.function[keyword.group] + '-' + self.inputs[keyword.group]
|
||||||
return {
|
ret = {
|
||||||
keyword.id: result_id,
|
keyword.id: result_id,
|
||||||
keyword.group: group,
|
keyword.group: group,
|
||||||
keyword.desc_inputs: self.inputs[keyword.desc_inputs],
|
keyword.desc_inputs: self.inputs[keyword.desc_inputs],
|
||||||
keyword.result: self.function[keyword.block](*self.inputs[keyword.desc_inputs])
|
keyword.result: self.function[keyword.block](*self.inputs[keyword.desc_inputs])
|
||||||
}
|
}
|
||||||
|
print("buxue------------------------------------------------")
|
||||||
|
print("inputs")
|
||||||
|
print(ret[keyword.desc_inputs])
|
||||||
|
print("outputs")
|
||||||
|
print(ret[keyword.result])
|
||||||
|
return ret
|
||||||
|
|
|
@ -1297,7 +1297,7 @@ raise_set = [
|
||||||
('ScatterNdUpdate', {
|
('ScatterNdUpdate', {
|
||||||
'block': (P.ScatterNdUpdate(), {'exception': TypeError}),
|
'block': (P.ScatterNdUpdate(), {'exception': TypeError}),
|
||||||
'desc_inputs': (Tensor(np.ones((2, 3), np.float32)),
|
'desc_inputs': (Tensor(np.ones((2, 3), np.float32)),
|
||||||
Tensor(np.ones((2, 2), np.int32)),
|
Tensor(np.ones((2, 2), np.float32)),
|
||||||
Tensor(np.ones((2,), np.float32))),
|
Tensor(np.ones((2,), np.float32))),
|
||||||
'desc_bprop': [[2, 3]]}),
|
'desc_bprop': [[2, 3]]}),
|
||||||
('Pack', {
|
('Pack', {
|
||||||
|
|
|
@ -16,13 +16,14 @@
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from mindspore import Tensor
|
from mindspore import Tensor, Parameter
|
||||||
from mindspore import context
|
from mindspore import context
|
||||||
from mindspore import dtype as mstype
|
from mindspore import dtype as mstype
|
||||||
from mindspore.nn import Cell
|
from mindspore.nn import Cell
|
||||||
from ....mindspore_test_framework.mindspore_test import mindspore_test
|
from ....mindspore_test_framework.mindspore_test import mindspore_test
|
||||||
from ....mindspore_test_framework.pipeline.forward.compile_forward \
|
from ....mindspore_test_framework.pipeline.forward.compile_forward \
|
||||||
import pipeline_for_compile_forward_ge_graph_for_case_by_case_config
|
import pipeline_for_compile_forward_ge_graph_for_case_by_case_config, \
|
||||||
|
pipeline_for_compile_forward_ge_graph_for_case_by_case_config_exception
|
||||||
|
|
||||||
|
|
||||||
class NetWorkSlicePositive(Cell):
|
class NetWorkSlicePositive(Cell):
|
||||||
|
@ -145,6 +146,160 @@ class TensorAssignWithSlice(Cell):
|
||||||
return z
|
return z
|
||||||
|
|
||||||
|
|
||||||
|
class TensorIndexByOneTensor(Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(TensorIndexByOneTensor, self).__init__()
|
||||||
|
self.const = Tensor(np.ones((5, 4, 7, 8)), mstype.int32)
|
||||||
|
|
||||||
|
def construct(self, x, index):
|
||||||
|
ret = x[index] + self.const
|
||||||
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
class TensorIndexByTwoTensors(Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(TensorIndexByTwoTensors, self).__init__()
|
||||||
|
self.const = Tensor(np.ones((3, 4, 5, 8)), mstype.int32)
|
||||||
|
|
||||||
|
def construct(self, x, index_0, index_1):
|
||||||
|
ret = x[index_0, index_1] + self.const
|
||||||
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
class TensorIndexByThreeTensors(Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(TensorIndexByThreeTensors, self).__init__()
|
||||||
|
self.const = Tensor(np.ones((5, 3, 4, 5)), mstype.int32)
|
||||||
|
|
||||||
|
def construct(self, x, index_0, index_1, index_2):
|
||||||
|
ret = x[index_0, index_1, index_2] + self.const
|
||||||
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
class TensorSetItemByOneTensorWithNumber(Cell):
|
||||||
|
def __init__(self, value):
|
||||||
|
super(TensorSetItemByOneTensorWithNumber, self).__init__()
|
||||||
|
self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
|
||||||
|
self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x")
|
||||||
|
self.value = value
|
||||||
|
|
||||||
|
def construct(self, index):
|
||||||
|
self.param[index] = self.value
|
||||||
|
ret = self.param + self.const
|
||||||
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
class TensorSetItemByOneTensorWithTensor(Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(TensorSetItemByOneTensorWithTensor, self).__init__()
|
||||||
|
self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
|
||||||
|
self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x")
|
||||||
|
|
||||||
|
def construct(self, index, value):
|
||||||
|
self.param[index] = value
|
||||||
|
ret = self.param + self.const
|
||||||
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
class TensorSetItemByOneTensorWithTupleOfNumber(Cell):
|
||||||
|
def __init__(self, value):
|
||||||
|
super(TensorSetItemByOneTensorWithTupleOfNumber, self).__init__()
|
||||||
|
self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
|
||||||
|
self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x")
|
||||||
|
self.value = value
|
||||||
|
|
||||||
|
def construct(self, index):
|
||||||
|
self.param[index] = self.value
|
||||||
|
ret = self.param + self.const
|
||||||
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
class TensorSetItemByOneTensorWithTupleOfTensor(Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(TensorSetItemByOneTensorWithTupleOfTensor, self).__init__()
|
||||||
|
self.const = Tensor(np.ones((6, 3, 8)), mstype.float32)
|
||||||
|
self.param = Parameter(Tensor(np.arange(6*3*8).reshape((6, 3, 8)), mstype.float32), name="x")
|
||||||
|
|
||||||
|
def construct(self, index, value_0, value_1, value_2):
|
||||||
|
self.param[index] = (value_0, value_1, value_2)
|
||||||
|
ret = self.param + self.const
|
||||||
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
class TensorSetItemByTensorsWithNumber(Cell):
|
||||||
|
def __init__(self, value):
|
||||||
|
super(TensorSetItemByTensorsWithNumber, self).__init__()
|
||||||
|
self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
|
||||||
|
self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x")
|
||||||
|
self.value = value
|
||||||
|
|
||||||
|
def construct(self, index_0, index_1, index_2):
|
||||||
|
self.param[index_0, index_1, index_2] = self.value
|
||||||
|
ret = self.param + self.const
|
||||||
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
class TensorSetItemByTensorsWithTensor(Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(TensorSetItemByTensorsWithTensor, self).__init__()
|
||||||
|
self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
|
||||||
|
self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x")
|
||||||
|
|
||||||
|
def construct(self, index_0, index_1, index_2, value):
|
||||||
|
self.param[index_0, index_1, index_2] = value
|
||||||
|
ret = self.param + self.const
|
||||||
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
class TensorSetItemByTensorsWithTensorNumberError(Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(TensorSetItemByTensorsWithTensorNumberError, self).__init__()
|
||||||
|
self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
|
||||||
|
self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x")
|
||||||
|
|
||||||
|
def construct(self, index_0, index_1, index_2, index_3, value):
|
||||||
|
self.param[index_0, index_1, index_2, index_3] = value
|
||||||
|
ret = self.param + self.const
|
||||||
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
class TensorSetItemByTensorsWithTupleOfNumber(Cell):
|
||||||
|
def __init__(self, value):
|
||||||
|
super(TensorSetItemByTensorsWithTupleOfNumber, self).__init__()
|
||||||
|
self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
|
||||||
|
self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x")
|
||||||
|
self.value = value
|
||||||
|
|
||||||
|
def construct(self, index_0, index_1, index_2):
|
||||||
|
self.param[index_0, index_1, index_2] = self.value
|
||||||
|
ret = self.param + self.const
|
||||||
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
class TensorSetItemByTensorsWithTupleOfTensor(Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(TensorSetItemByTensorsWithTupleOfTensor, self).__init__()
|
||||||
|
self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
|
||||||
|
self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x")
|
||||||
|
|
||||||
|
def construct(self, index_0, index_1, index_2, value_0, value_1, value_2):
|
||||||
|
self.param[index_0, index_1, index_2] = (value_0, value_1, value_2)
|
||||||
|
ret = self.param + self.const
|
||||||
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
class TensorSetItemByTensorsWithTupleOfTensorNumberError(Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(TensorSetItemByTensorsWithTupleOfTensorNumberError, self).__init__()
|
||||||
|
self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
|
||||||
|
self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x")
|
||||||
|
|
||||||
|
def construct(self, index_0, index_1, index_2, value_0, value_1):
|
||||||
|
self.param[index_0, index_1, index_2] = (value_0, value_1)
|
||||||
|
ret = self.param + self.const
|
||||||
|
return ret
|
||||||
|
|
||||||
|
|
||||||
def test_tensor_assign():
|
def test_tensor_assign():
|
||||||
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
|
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
|
||||||
net = TensorAssignWithSlice()
|
net = TensorAssignWithSlice()
|
||||||
|
@ -441,15 +596,206 @@ test_cases = [
|
||||||
'block': NetWorkSliceEllipsis(),
|
'block': NetWorkSliceEllipsis(),
|
||||||
'desc_inputs': [Tensor(np.ones([6, 7, 8, 9], np.int32))],
|
'desc_inputs': [Tensor(np.ones([6, 7, 8, 9], np.int32))],
|
||||||
}),
|
}),
|
||||||
|
('TensorIndexByOneTensor', {
|
||||||
|
'block': TensorIndexByOneTensor(),
|
||||||
|
'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32),
|
||||||
|
Tensor(np.random.randint(6, size=(5, 4)), mstype.int32)],
|
||||||
|
}),
|
||||||
|
('TensorIndexByTwoTensors', {
|
||||||
|
'block': TensorIndexByTwoTensors(),
|
||||||
|
'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32),
|
||||||
|
Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
|
||||||
|
Tensor(np.random.randint(7, size=(4, 5)), mstype.int32)],
|
||||||
|
}),
|
||||||
|
('TensorIndexByThreeTensors', {
|
||||||
|
'block': TensorIndexByThreeTensors(),
|
||||||
|
'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32),
|
||||||
|
Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
|
||||||
|
Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
|
||||||
|
Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)],
|
||||||
|
}),
|
||||||
|
('TensorSetItemByOneTensorWithNumber', {
|
||||||
|
'block': TensorSetItemByOneTensorWithNumber(value=0.0),
|
||||||
|
'desc_inputs': [Tensor(np.random.randint(4, size=(5, 4)), mstype.int32)],
|
||||||
|
}),
|
||||||
|
('TensorSetItemByOneTensorWithTensor', {
|
||||||
|
'block': TensorSetItemByOneTensorWithTensor(),
|
||||||
|
'desc_inputs': [Tensor(np.random.randint(3, size=(5, 4)), mstype.int32),
|
||||||
|
Tensor(np.zeros((4, 7, 8)), mstype.float32)],
|
||||||
|
}),
|
||||||
|
('TensorSetItemByOneTensorWithTupleOfNumber', {
|
||||||
|
'block': TensorSetItemByOneTensorWithTupleOfNumber(value=(0.0, 1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7)),
|
||||||
|
'desc_inputs': [Tensor(np.random.randint(5, size=(5, 4)), mstype.int32)],
|
||||||
|
}),
|
||||||
|
('TensorSetItemByOneTensorWithTupleOfTensor', {
|
||||||
|
'block': TensorSetItemByOneTensorWithTupleOfTensor(),
|
||||||
|
'desc_inputs': [Tensor(np.random.randint(6, size=(5, 4)), mstype.int32),
|
||||||
|
Tensor(np.zeros((8,), np.float32)),
|
||||||
|
Tensor(np.ones((8,), np.float32)),
|
||||||
|
Tensor(np.ones((8,), np.float32) * 2)],
|
||||||
|
}),
|
||||||
|
('TensorSetItemByTensorsWithNumber', {
|
||||||
|
'block': TensorSetItemByTensorsWithNumber(value=0.0),
|
||||||
|
'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
|
||||||
|
Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
|
||||||
|
Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)],
|
||||||
|
}),
|
||||||
|
('TensorSetItemByTensorsWithTensor', {
|
||||||
|
'block': TensorSetItemByTensorsWithTensor(),
|
||||||
|
'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
|
||||||
|
Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
|
||||||
|
Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32),
|
||||||
|
Tensor(np.zeros((4, 5)), mstype.float32)],
|
||||||
|
}),
|
||||||
|
('TensorSetItemByTensorsWithTupleOfNumber', {
|
||||||
|
'block': TensorSetItemByTensorsWithTupleOfNumber(value=(0.0, 1.1, 2.2, 3.3, 4.4)),
|
||||||
|
'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
|
||||||
|
Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
|
||||||
|
Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)],
|
||||||
|
}),
|
||||||
|
('TensorSetItemByTensorsWithTupleOfTensor', {
|
||||||
|
'block': TensorSetItemByTensorsWithTupleOfTensor(),
|
||||||
|
'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
|
||||||
|
Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
|
||||||
|
Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32),
|
||||||
|
Tensor(np.zeros((4, 5)), mstype.float32),
|
||||||
|
Tensor(np.ones((4, 5)), mstype.float32),
|
||||||
|
Tensor(np.ones((4, 5)) * 2, mstype.float32)],
|
||||||
|
})
|
||||||
|
]
|
||||||
|
|
||||||
|
raise_error_set = [
|
||||||
|
('TensorIndexByOneTensorDtypeError', {
|
||||||
|
'block': (TensorIndexByOneTensor(), {'exception': TypeError}),
|
||||||
|
'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32),
|
||||||
|
Tensor(np.random.randint(6, size=(5, 4)), mstype.int8)],
|
||||||
|
}),
|
||||||
|
('TensorIndexByTwoTensorsShapeError', {
|
||||||
|
'block': (TensorIndexByTwoTensors(), {'exception': ValueError}),
|
||||||
|
'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32),
|
||||||
|
Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
|
||||||
|
Tensor(np.random.randint(7, size=(2, 3, 5)), mstype.int32)],
|
||||||
|
}),
|
||||||
|
('TensorIndexByTwoTensorsDtypeError', {
|
||||||
|
'block': (TensorIndexByTwoTensors(), {'exception': TypeError}),
|
||||||
|
'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32),
|
||||||
|
Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
|
||||||
|
Tensor(np.random.randint(7, size=(4, 5)), mstype.float32)],
|
||||||
|
}),
|
||||||
|
('TensorIndexByThreeTensorsShapeError', {
|
||||||
|
'block': (TensorIndexByThreeTensors(), {'exception': ValueError}),
|
||||||
|
'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32),
|
||||||
|
Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
|
||||||
|
Tensor(np.random.randint(7, size=(3, 4, 5)), mstype.int32),
|
||||||
|
Tensor(np.random.randint(8, size=(5, 2, 4, 5)), mstype.int32)],
|
||||||
|
}),
|
||||||
|
('TensorIndexByThreeTensorsDtypeError', {
|
||||||
|
'block': (TensorIndexByThreeTensors(), {'exception': TypeError}),
|
||||||
|
'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32),
|
||||||
|
Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
|
||||||
|
Tensor(np.random.randint(7, size=(3, 4, 5)), mstype.int64),
|
||||||
|
Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)],
|
||||||
|
}),
|
||||||
|
('TensorSetItemByOneTensorWithNumberTypeError', {
|
||||||
|
'block': (TensorSetItemByOneTensorWithNumber(value=0), {'exception': TypeError}),
|
||||||
|
'desc_inputs': [Tensor(np.random.randint(4, size=(5, 4)), mstype.int32)],
|
||||||
|
}),
|
||||||
|
('TensorSetItemByOneTensorWithTensorShapeError', {
|
||||||
|
'block': (TensorSetItemByOneTensorWithTensor(), {'exception': ValueError}),
|
||||||
|
'desc_inputs': [Tensor(np.random.randint(3, size=(5, 4)), mstype.int32),
|
||||||
|
Tensor(np.zeros((6, 7, 8)), mstype.float32)],
|
||||||
|
}),
|
||||||
|
('TensorSetItemByOneTensorWithTensorDtypeError', {
|
||||||
|
'block': (TensorSetItemByOneTensorWithTensor(), {'exception': TypeError}),
|
||||||
|
'desc_inputs': [Tensor(np.random.randint(3, size=(5, 4)), mstype.int32),
|
||||||
|
Tensor(np.zeros((6, 7, 8)), mstype.int32)],
|
||||||
|
}),
|
||||||
|
('TensorSetItemByOneTensorWithTupleOfNumberTypeError', {
|
||||||
|
'block': (TensorSetItemByOneTensorWithTupleOfNumber(value=(0, 1, 2, 3, 4, 5, 6, 7)), {'exception': TypeError}),
|
||||||
|
'desc_inputs': [Tensor(np.random.randint(5, size=(5, 4)), mstype.int32)],
|
||||||
|
}),
|
||||||
|
('TensorSetItemByOneTensorWithTupleOfNumberNumberError', {
|
||||||
|
'block': (TensorSetItemByOneTensorWithTupleOfNumber(value=(0.0, 1.1, 2.2)), {'exception': ValueError}),
|
||||||
|
'desc_inputs': [Tensor(np.random.randint(5, size=(5, 4)), mstype.int32)],
|
||||||
|
}),
|
||||||
|
('TensorSetItemByOneTensorWithTupleOfTensorDtyeError', {
|
||||||
|
'block': (TensorSetItemByOneTensorWithTupleOfTensor(), {'exception': TypeError}),
|
||||||
|
'desc_inputs': [Tensor(np.random.randint(6, size=(5, 4)), mstype.int32),
|
||||||
|
Tensor(np.zeros((8,), np.int32)),
|
||||||
|
Tensor(np.ones((8,), np.int32)),
|
||||||
|
Tensor(np.ones((8,), np.float32) * 2)],
|
||||||
|
}),
|
||||||
|
('TensorSetItemByTensorsWithNumberTypeError', {
|
||||||
|
'block': (TensorSetItemByTensorsWithNumber(value=0), {'exception': TypeError}),
|
||||||
|
'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
|
||||||
|
Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
|
||||||
|
Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)],
|
||||||
|
}),
|
||||||
|
('TensorSetItemByTensorsWithTensorShapeError', {
|
||||||
|
'block': (TensorSetItemByTensorsWithTensor(), {'exception': ValueError}),
|
||||||
|
'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
|
||||||
|
Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
|
||||||
|
Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32),
|
||||||
|
Tensor(np.zeros((2, 5)), mstype.float32)],
|
||||||
|
}),
|
||||||
|
('TensorSetItemByTensorsWithTensorTypeError', {
|
||||||
|
'block': (TensorSetItemByTensorsWithTensor(), {'exception': TypeError}),
|
||||||
|
'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
|
||||||
|
Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
|
||||||
|
Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32),
|
||||||
|
Tensor(np.zeros((4, 5)), mstype.int32)],
|
||||||
|
}),
|
||||||
|
('TensorSetItemByTensorsWithTensorNumberError', {
|
||||||
|
'block': (TensorSetItemByTensorsWithTensorNumberError(), {'exception': IndexError}),
|
||||||
|
'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
|
||||||
|
Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
|
||||||
|
Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32),
|
||||||
|
Tensor(np.random.randint(8, size=(1, 3, 4, 5)), mstype.int32),
|
||||||
|
Tensor(np.zeros((2, 5)), mstype.float32)],
|
||||||
|
}),
|
||||||
|
('TensorSetItemByTensorsWithTupleOfNumberTypeError', {
|
||||||
|
'block': (TensorSetItemByTensorsWithTupleOfNumber(value=(0, 1, 2, 3, 4)), {'exception': TypeError}),
|
||||||
|
'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
|
||||||
|
Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
|
||||||
|
Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)],
|
||||||
|
}),
|
||||||
|
('TensorSetItemByTensorsWithTupleOfNumberNumberError', {
|
||||||
|
'block': (TensorSetItemByTensorsWithTupleOfNumber(value=(0.0, 1.0, 2.0, 3.0)), {'exception': ValueError}),
|
||||||
|
'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
|
||||||
|
Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
|
||||||
|
Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)],
|
||||||
|
}),
|
||||||
|
('TensorSetItemByTensorsWithTupleOfTensorNumberError', {
|
||||||
|
'block': (TensorSetItemByTensorsWithTupleOfTensorNumberError(), {'exception': ValueError}),
|
||||||
|
'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
|
||||||
|
Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
|
||||||
|
Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32),
|
||||||
|
Tensor(np.zeros((4, 5)), mstype.float32),
|
||||||
|
Tensor(np.ones((4, 5)), mstype.float32)],
|
||||||
|
}),
|
||||||
|
('TensorSetItemByTensorsWithTupleOfTensorTypeError', {
|
||||||
|
'block': (TensorSetItemByTensorsWithTupleOfTensor(), {'exception': TypeError}),
|
||||||
|
'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
|
||||||
|
Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
|
||||||
|
Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32),
|
||||||
|
Tensor(np.zeros((4, 5)), mstype.float32),
|
||||||
|
Tensor(np.ones((4, 5)), mstype.int32),
|
||||||
|
Tensor(np.ones((4, 5)) * 2, mstype.int32)],
|
||||||
|
})
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@mindspore_test(pipeline_for_compile_forward_ge_graph_for_case_by_case_config)
|
@mindspore_test(pipeline_for_compile_forward_ge_graph_for_case_by_case_config)
|
||||||
def test_compile():
|
def test_exec():
|
||||||
context.set_context(mode=context.GRAPH_MODE)
|
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
|
||||||
return test_cases
|
return test_cases
|
||||||
|
|
||||||
|
|
||||||
|
@mindspore_test(pipeline_for_compile_forward_ge_graph_for_case_by_case_config_exception)
|
||||||
|
def test_check_exception():
|
||||||
|
return raise_error_set
|
||||||
|
|
||||||
|
|
||||||
def test_tensor_slice_reduce_out_of_bounds_neg():
|
def test_tensor_slice_reduce_out_of_bounds_neg():
|
||||||
class NetWork(Cell):
|
class NetWork(Cell):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
|
|
@ -26,7 +26,7 @@ from mindspore.ops import functional as F
|
||||||
from mindspore.ops import operations as P
|
from mindspore.ops import operations as P
|
||||||
from mindspore.ops._grad.grad_base import bprop_getters
|
from mindspore.ops._grad.grad_base import bprop_getters
|
||||||
from mindspore.ops._grad.grad_math_ops import binop_grad_common
|
from mindspore.ops._grad.grad_math_ops import binop_grad_common
|
||||||
from mindspore.ops._utils import _get_broadcast_shape
|
from mindspore.ops._utils import get_broadcast_shape
|
||||||
from mindspore.ops.primitive import PrimitiveWithInfer, prim_attr_register
|
from mindspore.ops.primitive import PrimitiveWithInfer, prim_attr_register
|
||||||
from mindspore.train.loss_scale_manager import DynamicLossScaleManager
|
from mindspore.train.loss_scale_manager import DynamicLossScaleManager
|
||||||
|
|
||||||
|
@ -54,7 +54,7 @@ class MockSub(PrimitiveWithInfer):
|
||||||
self.init_prim_io_names(inputs=['x', 'y'], outputs=['output'])
|
self.init_prim_io_names(inputs=['x', 'y'], outputs=['output'])
|
||||||
|
|
||||||
def infer_shape(self, x_shape, y_shape):
|
def infer_shape(self, x_shape, y_shape):
|
||||||
return _get_broadcast_shape(x_shape, y_shape)
|
return get_broadcast_shape(x_shape, y_shape)
|
||||||
|
|
||||||
def infer_dtype(self, x_dtype, y_dtype):
|
def infer_dtype(self, x_dtype, y_dtype):
|
||||||
return x_dtype
|
return x_dtype
|
||||||
|
|
Loading…
Reference in New Issue