forked from mindspore-Ecosystem/mindspore
!17435 add DynamicStitch and SearchSorted ops for aicpu
From: @yanzhenxiang2020 Reviewed-by: @wuxuejian,@liangchenghui Signed-off-by: @wuxuejian
This commit is contained in:
commit
83f68e3a33
|
@ -40,7 +40,7 @@ void AicpuMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<
|
||||||
// For compatibility with the current framework
|
// For compatibility with the current framework
|
||||||
if (op_name == kPrint || op_name == kGetNext || op_name == kPack || op_name == kMeshgrid ||
|
if (op_name == kPrint || op_name == kGetNext || op_name == kPack || op_name == kMeshgrid ||
|
||||||
op_name == kStackInitOpName || op_name == kStackDestroyOpName || op_name == kStackPushOpName ||
|
op_name == kStackInitOpName || op_name == kStackDestroyOpName || op_name == kStackPushOpName ||
|
||||||
op_name == kStackPopOpName) {
|
op_name == kStackPopOpName || op_name == kDynamicStitch) {
|
||||||
AicpuMetadataInfoForSpecialNodes(kernel_node, kernel_info_list);
|
AicpuMetadataInfoForSpecialNodes(kernel_node, kernel_info_list);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
@ -56,7 +56,8 @@ void AicpuMetadataInfoForSpecialNodes(const CNodePtr &kernel_node,
|
||||||
std::vector<TypeId> inputs_type{};
|
std::vector<TypeId> inputs_type{};
|
||||||
auto op_name = AnfAlgo::GetCNodeName(kernel_node);
|
auto op_name = AnfAlgo::GetCNodeName(kernel_node);
|
||||||
if (op_name == kPrint || op_name == kPack || op_name == kMeshgrid || op_name == kStackInitOpName ||
|
if (op_name == kPrint || op_name == kPack || op_name == kMeshgrid || op_name == kStackInitOpName ||
|
||||||
op_name == kStackDestroyOpName || op_name == kStackPushOpName || op_name == kStackPopOpName) {
|
op_name == kStackDestroyOpName || op_name == kStackPushOpName || op_name == kStackPopOpName ||
|
||||||
|
op_name == kDynamicStitch) {
|
||||||
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||||
for (size_t input_index = 0; input_index < input_num; ++input_index) {
|
for (size_t input_index = 0; input_index < input_num; ++input_index) {
|
||||||
inputs_format.emplace_back(kOpFormat_DEFAULT);
|
inputs_format.emplace_back(kOpFormat_DEFAULT);
|
||||||
|
|
|
@ -60,7 +60,10 @@ constexpr auto kDropout2D = "Dropout2D";
|
||||||
constexpr auto kDropout3D = "Dropout3D";
|
constexpr auto kDropout3D = "Dropout3D";
|
||||||
constexpr auto kMaskedSelect = "MaskedSelect";
|
constexpr auto kMaskedSelect = "MaskedSelect";
|
||||||
constexpr auto kMaskedSelectGrad = "MaskedSelectGrad";
|
constexpr auto kMaskedSelectGrad = "MaskedSelectGrad";
|
||||||
const std::set<std::string> kCustAiCpuKernelOps{kIdentity, kMaskedSelect, kMaskedSelectGrad};
|
constexpr auto kDynamicStitch = "DynamicStitch";
|
||||||
|
constexpr auto kSearchSorted = "SearchSorted";
|
||||||
|
const std::set<std::string> kCustAiCpuKernelOps{kIdentity, kMaskedSelect, kMaskedSelectGrad, kDynamicStitch,
|
||||||
|
kSearchSorted};
|
||||||
const std::set<std::string> kCacheKernelOps{kUpdateCache, kCacheSwapTable, kSubAndFilter,
|
const std::set<std::string> kCacheKernelOps{kUpdateCache, kCacheSwapTable, kSubAndFilter,
|
||||||
kPadAndShift, kDropout3D, kDropout2D};
|
kPadAndShift, kDropout3D, kDropout2D};
|
||||||
|
|
||||||
|
|
|
@ -35,6 +35,7 @@ constexpr auto kUniqueOpName = "Unique";
|
||||||
constexpr auto kMaskedSelectOpName = "MaskedSelect";
|
constexpr auto kMaskedSelectOpName = "MaskedSelect";
|
||||||
constexpr auto kComputeAccidentalHitsOpName = "ComputeAccidentalHits";
|
constexpr auto kComputeAccidentalHitsOpName = "ComputeAccidentalHits";
|
||||||
constexpr auto kCTCGreedyDecoderOpName = "CTCGreedyDecoder";
|
constexpr auto kCTCGreedyDecoderOpName = "CTCGreedyDecoder";
|
||||||
|
constexpr auto kDynamicStitchOpName = "DynamicStitch";
|
||||||
constexpr auto kFour2FiveOpName = "Four2Five";
|
constexpr auto kFour2FiveOpName = "Four2Five";
|
||||||
constexpr auto kFive2FourOpName = "Five2Four";
|
constexpr auto kFive2FourOpName = "Five2Four";
|
||||||
constexpr auto kConv3DOpName = "Conv3D";
|
constexpr auto kConv3DOpName = "Conv3D";
|
||||||
|
@ -560,9 +561,9 @@ const std::set<std::string> kHWSpecialFormatSet = {
|
||||||
|
|
||||||
const std::set<TypeId> kFloatDataTypeSet = {kNumberTypeFloat16, kNumberTypeFloat32};
|
const std::set<TypeId> kFloatDataTypeSet = {kNumberTypeFloat16, kNumberTypeFloat32};
|
||||||
|
|
||||||
const std::set<std::string> kComputeDepend = {kUniqueOpName, kComputeAccidentalHitsOpName, kSubAndFilterOpName,
|
const std::set<std::string> kComputeDepend = {kUniqueOpName, kComputeAccidentalHitsOpName, kSubAndFilterOpName,
|
||||||
kPadAndShiftOpName, kCTCGreedyDecoderOpName, kDropoutGenMaskOpName,
|
kPadAndShiftOpName, kCTCGreedyDecoderOpName, kDropoutGenMaskOpName,
|
||||||
kMaskedSelectOpName};
|
kMaskedSelectOpName, kDynamicStitchOpName};
|
||||||
|
|
||||||
const std::set<std::string> k3DFormatSet = {kOpFormat_NCDHW, kOpFormat_NDC1HWC0, kOpFormat_FRACTAL_Z_3D,
|
const std::set<std::string> k3DFormatSet = {kOpFormat_NCDHW, kOpFormat_NDC1HWC0, kOpFormat_FRACTAL_Z_3D,
|
||||||
kOpFormat_NDHWC, kOpFormat_DHWCN, kOpFormat_DHWNC};
|
kOpFormat_NDHWC, kOpFormat_DHWCN, kOpFormat_DHWNC};
|
||||||
|
|
|
@ -184,6 +184,8 @@ AbstractBasePtr InferImplUnique(const AnalysisEnginePtr &, const PrimitivePtr &p
|
||||||
const AbstractBasePtrList &args_spec_list);
|
const AbstractBasePtrList &args_spec_list);
|
||||||
AbstractBasePtr InferImplCTCGreedyDecoder(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
AbstractBasePtr InferImplCTCGreedyDecoder(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
const AbstractBasePtrList &args_spec_list);
|
const AbstractBasePtrList &args_spec_list);
|
||||||
|
AbstractBasePtr InferImplDynamicStitch(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
|
const AbstractBasePtrList &args_spec_list);
|
||||||
AbstractBasePtr InferImplScatterAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
AbstractBasePtr InferImplScatterAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
const AbstractBasePtrList &args_spec_list);
|
const AbstractBasePtrList &args_spec_list);
|
||||||
AbstractBasePtr InferImplScatterUpdate(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
AbstractBasePtr InferImplScatterUpdate(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
|
|
|
@ -22,6 +22,7 @@
|
||||||
#include "abstract/utils.h"
|
#include "abstract/utils.h"
|
||||||
#include "abstract/param_validator.h"
|
#include "abstract/param_validator.h"
|
||||||
#include "utils/shape_utils.h"
|
#include "utils/shape_utils.h"
|
||||||
|
#include "ops/op_utils.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace abstract {
|
namespace abstract {
|
||||||
|
@ -1140,5 +1141,59 @@ AbstractBasePtr InferImplMaskedSelect(const AnalysisEnginePtr &, const Primitive
|
||||||
return std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(y_shape, min_shape, max_shape));
|
return std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(y_shape, min_shape, max_shape));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
AbstractBasePtr InferImplDynamicStitch(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
|
const AbstractBasePtrList &args_spec_list) {
|
||||||
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
|
auto prim_name = primitive->name();
|
||||||
|
CheckAndConvertUtils::CheckInteger("input number", args_spec_list.size(), kEqual, 2, prim_name);
|
||||||
|
for (const auto &item : args_spec_list) {
|
||||||
|
MS_EXCEPTION_IF_NULL(item);
|
||||||
|
}
|
||||||
|
|
||||||
|
// input0: indices
|
||||||
|
auto input_tuple = args_spec_list[0]->cast<abstract::AbstractSequeuePtr>();
|
||||||
|
MS_EXCEPTION_IF_NULL(input_tuple);
|
||||||
|
auto indices = input_tuple->elements();
|
||||||
|
auto indices0 = indices[0]->cast<abstract::AbstractTensorPtr>();
|
||||||
|
MS_EXCEPTION_IF_NULL(indices0);
|
||||||
|
auto indices0_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(indices0->BuildShape())[kShape];
|
||||||
|
|
||||||
|
// input1: data
|
||||||
|
auto input_tuple_1 = args_spec_list[1]->cast<abstract::AbstractSequeuePtr>();
|
||||||
|
MS_EXCEPTION_IF_NULL(input_tuple_1);
|
||||||
|
auto data = input_tuple_1->elements();
|
||||||
|
auto data0 = data[0]->cast<abstract::AbstractTensorPtr>();
|
||||||
|
MS_EXCEPTION_IF_NULL(data0);
|
||||||
|
auto data0_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(data0->BuildShape())[kShape];
|
||||||
|
if (indices.size() != data.size()) {
|
||||||
|
MS_LOG(EXCEPTION) << "The number of input[0] must be the same as input[0]!";
|
||||||
|
}
|
||||||
|
|
||||||
|
int indices_total_size = 0;
|
||||||
|
std::map<std::string, TypePtr> types;
|
||||||
|
types.emplace("data0", data0->BuildType());
|
||||||
|
for (size_t i = 1; i < data.size(); ++i) {
|
||||||
|
auto indicesi_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(indices[i]->BuildShape())[kShape];
|
||||||
|
auto datai_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(data[i]->BuildShape())[kShape];
|
||||||
|
if (indicesi_shape.size() >= datai_shape.size()) {
|
||||||
|
MS_LOG(EXCEPTION) << "The rank of indices[i] must be < rank of data[i]!";
|
||||||
|
}
|
||||||
|
indices_total_size += indicesi_shape.size();
|
||||||
|
}
|
||||||
|
std::set<TypePtr> valid_types = ops::common_valid_types;
|
||||||
|
auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim_name);
|
||||||
|
|
||||||
|
ShapeVector out_shape = {abstract::Shape::SHP_ANY};
|
||||||
|
for (size_t i = indices0_shape.size(); i < data0_shape.size(); ++i) {
|
||||||
|
out_shape.push_back(data0_shape[i]);
|
||||||
|
}
|
||||||
|
const size_t EXPAND_MAX = 10;
|
||||||
|
ShapeVector min_shape = out_shape;
|
||||||
|
ShapeVector max_shape = out_shape;
|
||||||
|
min_shape[0] = 1;
|
||||||
|
max_shape[0] = indices_total_size * EXPAND_MAX;
|
||||||
|
return std::make_shared<AbstractTensor>(infer_type,
|
||||||
|
std::make_shared<abstract::Shape>(out_shape, min_shape, max_shape));
|
||||||
|
}
|
||||||
} // namespace abstract
|
} // namespace abstract
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -96,6 +96,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
|
||||||
{prim::kPrimCacheSwapTable, {InferImplCacheSwapTable, nullptr, true}},
|
{prim::kPrimCacheSwapTable, {InferImplCacheSwapTable, nullptr, true}},
|
||||||
{prim::kPrimUpdateCache, {InferImplUpdateCache, nullptr, true}},
|
{prim::kPrimUpdateCache, {InferImplUpdateCache, nullptr, true}},
|
||||||
{prim::kPrimComputeAccidentalHits, {InferImplComputeAccidentalHits, nullptr, true}},
|
{prim::kPrimComputeAccidentalHits, {InferImplComputeAccidentalHits, nullptr, true}},
|
||||||
|
{prim::kPrimDynamicStitch, {InferImplDynamicStitch, nullptr, true}},
|
||||||
{prim::kPrimPadAndShift, {InferImplPadAndShift, nullptr, true}},
|
{prim::kPrimPadAndShift, {InferImplPadAndShift, nullptr, true}},
|
||||||
{prim::kPrimDynamicShape, {InferImplDynamicShape, nullptr, true}},
|
{prim::kPrimDynamicShape, {InferImplDynamicShape, nullptr, true}},
|
||||||
{prim::kPrimMapUniform, {InferImplMapUniform, nullptr, true}},
|
{prim::kPrimMapUniform, {InferImplMapUniform, nullptr, true}},
|
||||||
|
|
|
@ -278,6 +278,7 @@ inline const PrimitivePtr kPrimConv3DBackpropFilter = std::make_shared<Primitive
|
||||||
inline const PrimitivePtr kPrimCustomNormalize = std::make_shared<Primitive>("CustomNormalize");
|
inline const PrimitivePtr kPrimCustomNormalize = std::make_shared<Primitive>("CustomNormalize");
|
||||||
inline const PrimitivePtr kPrimDepthwiseConv2dNative = std::make_shared<Primitive>("DepthwiseConv2dNative");
|
inline const PrimitivePtr kPrimDepthwiseConv2dNative = std::make_shared<Primitive>("DepthwiseConv2dNative");
|
||||||
inline const PrimitivePtr kPrimCTCGreedyDecoder = std::make_shared<Primitive>("CTCGreedyDecoder");
|
inline const PrimitivePtr kPrimCTCGreedyDecoder = std::make_shared<Primitive>("CTCGreedyDecoder");
|
||||||
|
inline const PrimitivePtr kPrimDynamicStitch = std::make_shared<Primitive>("DynamicStitch");
|
||||||
inline const PrimitivePtr kPrimDepthwiseConv2dNativeBackpropFilter =
|
inline const PrimitivePtr kPrimDepthwiseConv2dNativeBackpropFilter =
|
||||||
std::make_shared<Primitive>("DepthwiseConv2dNativeBackpropFilter");
|
std::make_shared<Primitive>("DepthwiseConv2dNativeBackpropFilter");
|
||||||
inline const PrimitivePtr kPrimDepthwiseConv2dNativeBackpropInput =
|
inline const PrimitivePtr kPrimDepthwiseConv2dNativeBackpropInput =
|
||||||
|
|
|
@ -29,6 +29,7 @@ from .pad_and_shift import _pad_and_shift_aicpu
|
||||||
from .dropout_genmask import _dropout_genmask_aicpu
|
from .dropout_genmask import _dropout_genmask_aicpu
|
||||||
from .dropout2d import _dropout2d_aicpu
|
from .dropout2d import _dropout2d_aicpu
|
||||||
from .dropout3d import _dropout3d_aicpu
|
from .dropout3d import _dropout3d_aicpu
|
||||||
|
from .dynamic_stitch import _dynamic_stitch_aicpu
|
||||||
from .get_next import _get_next_aicpu
|
from .get_next import _get_next_aicpu
|
||||||
from .print_tensor import _print_aicpu
|
from .print_tensor import _print_aicpu
|
||||||
from .topk import _top_k_aicpu
|
from .topk import _top_k_aicpu
|
||||||
|
@ -39,6 +40,7 @@ from .squeeze import _squeeze_aicpu
|
||||||
from .expand_dims import _expand_dims_aicpu
|
from .expand_dims import _expand_dims_aicpu
|
||||||
from .randperm import _randperm_aicpu
|
from .randperm import _randperm_aicpu
|
||||||
from .random_choice_with_mask import _random_choice_with_mask_aicpu
|
from .random_choice_with_mask import _random_choice_with_mask_aicpu
|
||||||
|
from .search_sorted import _search_sorted_aicpu
|
||||||
from .stack import _stack_aicpu
|
from .stack import _stack_aicpu
|
||||||
from .uniform_candidate_sampler import _uniform_candidate_sampler_aicpu
|
from .uniform_candidate_sampler import _uniform_candidate_sampler_aicpu
|
||||||
from .log_uniform_candidate_sampler import _log_uniform_candidate_sampler_aicpu
|
from .log_uniform_candidate_sampler import _log_uniform_candidate_sampler_aicpu
|
||||||
|
|
|
@ -0,0 +1,42 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""DynamicStitch op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
|
||||||
|
|
||||||
|
dynamic_stitch_op_info = AiCPURegOp("DynamicStitch") \
|
||||||
|
.fusion_type("OPAQUE") \
|
||||||
|
.input(0, "indices", "dynamic") \
|
||||||
|
.input(1, "data", "dynamic") \
|
||||||
|
.output(0, "y", "required") \
|
||||||
|
.dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.I8_Default) \
|
||||||
|
.dtype_format(DataType.I16_Default, DataType.I16_Default, DataType.I16_Default) \
|
||||||
|
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
|
||||||
|
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \
|
||||||
|
.dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.U8_Default) \
|
||||||
|
.dtype_format(DataType.U16_Default, DataType.U16_Default, DataType.U16_Default) \
|
||||||
|
.dtype_format(DataType.U32_Default, DataType.U32_Default, DataType.U32_Default) \
|
||||||
|
.dtype_format(DataType.U64_Default, DataType.U64_Default, DataType.U64_Default) \
|
||||||
|
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||||
|
.dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.F64_Default) \
|
||||||
|
.dtype_format(DataType.BOOL_Default, DataType.BOOL_Default, DataType.BOOL_Default) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
@op_info_register(dynamic_stitch_op_info)
|
||||||
|
def _dynamic_stitch_aicpu():
|
||||||
|
"""DynamicStitch AiCPU register"""
|
||||||
|
return
|
|
@ -0,0 +1,38 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""SearchSorted op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
|
||||||
|
|
||||||
|
search_sorted_op_info = AiCPURegOp("SearchSorted") \
|
||||||
|
.fusion_type("OPAQUE") \
|
||||||
|
.attr("out_int32", "bool") \
|
||||||
|
.attr("right", "bool") \
|
||||||
|
.input(0, "sequence", "required") \
|
||||||
|
.input(1, "values", "required") \
|
||||||
|
.output(0, "output", "required") \
|
||||||
|
.dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.I8_Default) \
|
||||||
|
.dtype_format(DataType.I16_Default, DataType.I16_Default, DataType.I16_Default) \
|
||||||
|
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
|
||||||
|
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||||
|
.dtype_format(DataType.F64_Default, DataType.I64_Default, DataType.F64_Default) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
@op_info_register(search_sorted_op_info)
|
||||||
|
def _search_sorted_aicpu():
|
||||||
|
"""SearchSorted AiCPU register"""
|
||||||
|
return
|
|
@ -1022,3 +1022,92 @@ class StackDestroy(PrimitiveWithInfer):
|
||||||
def __init__(self, index=1):
|
def __init__(self, index=1):
|
||||||
"""StackDestroy"""
|
"""StackDestroy"""
|
||||||
validator.check_value_type("index", index, [int], self.name)
|
validator.check_value_type("index", index, [int], self.name)
|
||||||
|
|
||||||
|
|
||||||
|
class DynamicStitch(PrimitiveWithCheck):
|
||||||
|
r"""
|
||||||
|
Interleave the values from the data tensors into a single tensor.
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
- **indices** (Union[tuple, list]) - A Tuple or list of Tensor objects with the same shape and type.
|
||||||
|
- **data** (Union[tuple, list]) - A Tuple or list of Tensor objects with the same shape and type.
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
Tensor. A stacked Tensor with the same type as `data`.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: If the data types of elements in `data` or `indices` are not the same.
|
||||||
|
ValueError: If the length of `data` or `indices` is not greater than 1.
|
||||||
|
|
||||||
|
Supported Platforms:
|
||||||
|
``Ascend``
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> x1 = Tensor([6], mstype.int32)
|
||||||
|
>>> x2 = Tensor(np.array([4, 1]), mstype.int32)
|
||||||
|
>>> x3 = Tensor(np.array([[5, 2], [0, 3]]), mstype.int32)
|
||||||
|
>>> y1 = Tensor(np.array([[6, 1]]), mstype.int32)
|
||||||
|
>>> y2 = Tensor(np.array([[41, 42], [11, 12]]), mstype.int32)
|
||||||
|
>>> y3 = Tensor(np.array([[[51, 52], [21, 22]], [[1, 2], [31, 32]]]), mstype.int32)
|
||||||
|
>>> stitch = ops.DynamicStitch()
|
||||||
|
>>> output = stitch([x1, x2, x3], [y1, y2, y3])
|
||||||
|
>>> print(output)
|
||||||
|
[[ 1 2]
|
||||||
|
[11 12]
|
||||||
|
[21 22]
|
||||||
|
[31 32]
|
||||||
|
[41 42]
|
||||||
|
[51 52]
|
||||||
|
[61 62]]
|
||||||
|
"""
|
||||||
|
|
||||||
|
@prim_attr_register
|
||||||
|
def __init__(self):
|
||||||
|
"""Initialize DynamicStitch"""
|
||||||
|
|
||||||
|
def check_shape(self, indices_shape, data_shape):
|
||||||
|
validator.check_value_type("shape of indices", indices_shape, [tuple, list], self.name)
|
||||||
|
validator.check_int(len(indices_shape), 1, Rel.GE, "len of indices_shape", self.name)
|
||||||
|
indices_dim0 = len(indices_shape[0])
|
||||||
|
indices_num = len(indices_shape)
|
||||||
|
|
||||||
|
validator.check_value_type("shape of data", data_shape, [tuple, list], self.name)
|
||||||
|
validator.check_int(len(data_shape), 1, Rel.GE, "len of data_shape", self.name)
|
||||||
|
data_dim0 = len(data_shape[0])
|
||||||
|
data_num = len(indices_shape)
|
||||||
|
|
||||||
|
validator.check("size of indices", indices_num, 'size of data', data_num, Rel.EQ, self.name)
|
||||||
|
|
||||||
|
# shape of `data` must start with shape of `indices`
|
||||||
|
for i in range(0, indices_num):
|
||||||
|
indices_dim = len(indices_shape[i])
|
||||||
|
data_dim = len(data_shape[i])
|
||||||
|
validator.check(f"dim of indices[{i}]", indices_dim, f"dim of data[{i}]", data_dim, Rel.LT, self.name)
|
||||||
|
if data_shape[i][:indices_dim] != data_shape[i][:indices_dim]:
|
||||||
|
raise ValueError(f"data[{i}].shape: {data_shape} does not start with indices[{i}].shape: {data_shape}")
|
||||||
|
|
||||||
|
# the last-(data_dim0-indices_dim0)-dim of data shape must end with same shape.
|
||||||
|
base_extra = data_dim0 - indices_dim0
|
||||||
|
for i in range(0, data_num):
|
||||||
|
indices_dim = len(indices_shape[i])
|
||||||
|
data_dim = len(data_shape[i])
|
||||||
|
extra = data_dim - indices_dim
|
||||||
|
validator.check(f"extra dim of data[{i}]", extra,
|
||||||
|
f"extra dim of data[0]", base_extra, Rel.EQ, self.name)
|
||||||
|
validator.check(f"data[0].shape[{indices_dim0}:]", data_shape[0][indices_dim0:],
|
||||||
|
f"data[{i}].shape[{len(indices_shape[i])}:]",
|
||||||
|
data_shape[i][indices_dim:], Rel.EQ, self.name)
|
||||||
|
|
||||||
|
out_shape = [-1] + data_shape[0][indices_dim0:]
|
||||||
|
return out_shape
|
||||||
|
|
||||||
|
def check_dtype(self, indices_type, data_type):
|
||||||
|
validator.check_subclass("indices[0]", indices_type[0], mstype.tensor, self.name)
|
||||||
|
validator.check_subclass("data[0]", data_type[0], mstype.tensor, self.name)
|
||||||
|
indices_num = len(indices_type)
|
||||||
|
for i in range(0, indices_num):
|
||||||
|
validator.check_tensor_dtype_valid(f'indices[{i}]', indices_type[i], mstype.int32, self.name)
|
||||||
|
validator.check_tensor_dtype_valid(f'data[{i}]', data_type[i],
|
||||||
|
mstype.number_type + (mstype.bool_,), self.name)
|
||||||
|
validator.check(f"type of data[{i}]", data_type[i], f"type of data[0]", data_type[0], Rel.EQ, self.name)
|
||||||
|
return data_type[0]
|
||||||
|
|
|
@ -0,0 +1,51 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
import numpy as np
|
||||||
|
import mindspore
|
||||||
|
import mindspore.context as context
|
||||||
|
import mindspore.nn as nn
|
||||||
|
from mindspore import Tensor
|
||||||
|
from mindspore.ops.operations import _inner_ops as P
|
||||||
|
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||||
|
|
||||||
|
|
||||||
|
class Net(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(Net, self).__init__()
|
||||||
|
self.stitch = P.DynamicStitch()
|
||||||
|
|
||||||
|
def construct(self, indices, data):
|
||||||
|
return self.stitch(indices, data)
|
||||||
|
|
||||||
|
|
||||||
|
def test_net_int32():
|
||||||
|
x1 = Tensor([6], mindspore.int32)
|
||||||
|
x2 = Tensor(np.array([4, 1]), mindspore.int32)
|
||||||
|
x3 = Tensor(np.array([[5, 2], [0, 3]]), mindspore.int32)
|
||||||
|
y1 = Tensor(np.array([[61, 62]]), mindspore.int32)
|
||||||
|
y2 = Tensor(np.array([[41, 42], [11, 12]]), mindspore.int32)
|
||||||
|
y3 = Tensor(np.array([[[51, 52], [21, 22]], [[1, 2], [31, 32]]]), mindspore.int32)
|
||||||
|
expected = np.array([[1, 2], [11, 12], [21, 22],
|
||||||
|
[31, 32], [41, 42], [51, 52], [61, 62]]).astype(np.int32)
|
||||||
|
|
||||||
|
print(x1.shape, x2.shape, x3.shape)
|
||||||
|
print(y1.shape, y2.shape, y3.shape)
|
||||||
|
indices = [x1, x2, x3]
|
||||||
|
data = [y1, y2, y3]
|
||||||
|
net = Net()
|
||||||
|
output = net(indices, data)
|
||||||
|
print(output.asnumpy())
|
||||||
|
assert np.array_equal(output.asnumpy(), expected)
|
|
@ -0,0 +1,42 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
import numpy as np
|
||||||
|
import mindspore
|
||||||
|
import mindspore.context as context
|
||||||
|
import mindspore.nn as nn
|
||||||
|
from mindspore import Tensor
|
||||||
|
from mindspore.ops import operations as P
|
||||||
|
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||||
|
|
||||||
|
|
||||||
|
class Net(nn.Cell):
|
||||||
|
def __init__(self, right=False, out_int32=True):
|
||||||
|
super(Net, self).__init__()
|
||||||
|
self.search = P.SearchSorted(out_int32=out_int32, right=right)
|
||||||
|
|
||||||
|
def construct(self, sequence, values):
|
||||||
|
return self.search(sequence, values)
|
||||||
|
|
||||||
|
|
||||||
|
def test_net_int32():
|
||||||
|
np.random.seed(1)
|
||||||
|
input1 = np.sort(np.array(np.random.randint(10, size=(2, 3, 9)), dtype=np.int32), axis=-1)
|
||||||
|
sequence = Tensor(input1, mindspore.int32)
|
||||||
|
input2 = np.array(np.random.randint(10, size=(2, 3, 1)), dtype=np.int32)
|
||||||
|
values = Tensor(input2, mindspore.int32)
|
||||||
|
net = Net()
|
||||||
|
output = net(sequence, values)
|
||||||
|
print(output)
|
Loading…
Reference in New Issue