!31375 [assistant][ops] Add Slice

Merge pull request !31375 from TR-nbu/Slice
This commit is contained in:
i-robot 2022-06-10 09:57:03 +00:00 committed by Gitee
commit 645d9ba6fb
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
10 changed files with 237 additions and 125 deletions

View File

@ -33,7 +33,8 @@ std::map<string, std::vector<std::pair<string, size_t>>> AicpuOpAttrToInputMap =
{prim::kPrimOneHot->name(), {{"depth", 1}}},
{prim::kPrimConcat->name(), {{"axis", 0}}},
{prim::kPrimTranspose->name(), {{"perm", 1}}},
{prim::kPrimGather->name(), {{"axis", 2}}}};
{prim::kPrimGather->name(), {{"axis", 2}}},
{prim::kPrimSlice->name(), {{"begin", 1}, {"size", 2}}}};
bool GetAicpuOpAttrToInputInfo(const CNodePtr &kernel_node, std::vector<std::pair<string, size_t>> *info) {
std::string op_name = common::AnfAlgo::GetCNodeName(kernel_node);

View File

@ -15,6 +15,7 @@
*/
#include "plugin/device/cpu/kernel/slice_cpu_kernel.h"
#include <complex>
#include <algorithm>
#include <unordered_map>
#include "include/common/thread_pool.h"
@ -42,9 +43,19 @@ void SliceCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
cnode_ptr_ = kernel_node;
static const std::unordered_map<TypeId, int> type_size_map = {{kNumberTypeBool, sizeof(bool)},
{kNumberTypeInt32, sizeof(int)},
{kNumberTypeInt8, sizeof(int8_t)},
{kNumberTypeInt16, sizeof(int16_t)},
{kNumberTypeInt32, sizeof(int32_t)},
{kNumberTypeInt64, sizeof(int64_t)},
{kNumberTypeUInt8, sizeof(uint8_t)},
{kNumberTypeUInt16, sizeof(uint16_t)},
{kNumberTypeUInt32, sizeof(uint32_t)},
{kNumberTypeUInt64, sizeof(uint64_t)},
{kNumberTypeFloat32, sizeof(float)},
{kNumberTypeFloat64, sizeof(double)}};
{kNumberTypeFloat64, sizeof(double)},
{kNumberTypeFloat16, sizeof(float16)},
{kNumberTypeComplex64, sizeof(std::complex<float>)},
{kNumberTypeComplex128, sizeof(std::complex<double>)}};
auto input_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
if (input_shape.size() > DIMENSION_8D || input_shape.empty()) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_
@ -70,9 +81,7 @@ void SliceCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
TypeId dtype = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
auto size_pair = type_size_map.find(dtype);
if (size_pair == type_size_map.end()) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_
<< "', the dtype of 'input_x' must be bool, int32, float32 or float64, but got "
<< TypeIdToType(dtype)->ToString();
MS_LOG(EXCEPTION) << "Slice supports type in type_size_map, but got " << TypeIdToType(dtype)->ToString();
}
data_size_ = size_pair->second;
}
@ -155,9 +164,9 @@ bool SliceCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs, co
for (size_t i = 0; i < begin.size(); ++i) {
if (input_shape[i] < LongToSize(begin[i] + size[i])) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_
<< "', slice shape can not be greater than origin shape. But in dimension i=" << i
<< ", origin shape 'input_shape[i]' is " << input_shape[i] << " and slice shape is "
<< LongToSize(begin[i] + size[i]);
<< "', slice shape should be not greater than origin shape. But in dimension i=" << i
<< ", origin shape 'input_shape[i]' is " << input_shape[i]
<< " and slice shape 'LongToSize(begin[i] + size[i])' is " << LongToSize(begin[i] + size[i]);
}
}
InitSliceParam(input_shape, begin, size);

View File

@ -31,7 +31,6 @@ class SliceCpuKernelMod : public DeprecatedNativeCpuKernelMod {
~SliceCpuKernelMod() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
@ -39,9 +38,19 @@ class SliceCpuKernelMod : public DeprecatedNativeCpuKernelMod {
std::vector<KernelAttr> GetOpSupport() override {
static std::vector<KernelAttr> support_list = {
KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64),
KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128),
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
@ -69,5 +78,4 @@ class SliceCpuKernelMod : public DeprecatedNativeCpuKernelMod {
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SLICE_CPU_KERNEL_H_

View File

@ -164,6 +164,7 @@ constexpr auto kMeshgrid = "Meshgrid";
constexpr auto kScatterNdMax = "ScatterNdMax";
constexpr auto kScatterNdMin = "ScatterNdMin";
constexpr auto kCSRSparseMatrixToSparseTensor = "CSRSparseMatrixToSparseTensor";
constexpr auto kSlice = "Slice";
// NN
constexpr auto kFractionalMaxPool3DWithFixedKsize = "FractionalMaxPool3DWithFixedKsize";
@ -361,7 +362,7 @@ GVAR_DEF(PrimitivePtr, kPrimComputeAccidentalHits, std::make_shared<Primitive>("
GVAR_DEF(PrimitivePtr, kPrimCacheSwapTable, std::make_shared<Primitive>("CacheSwapTable"));
GVAR_DEF(PrimitivePtr, kPrimDynamicAssign, std::make_shared<Primitive>("DynamicAssign"));
GVAR_DEF(PrimitivePtr, kPrimPadAndShift, std::make_shared<Primitive>("PadAndShift"));
GVAR_DEF(PrimitivePtr, kPrimSlice, std::make_shared<Primitive>("Slice"));
GVAR_DEF(PrimitivePtr, kPrimSlice, std::make_shared<Primitive>(kSlice));
GVAR_DEF(PrimitivePtr, kPrimSliceGrad, std::make_shared<Primitive>("SliceGrad"));
GVAR_DEF(PrimitivePtr, kPrimSliceFusion, std::make_shared<Primitive>("SliceFusion"));
GVAR_DEF(PrimitivePtr, kPrimTile, std::make_shared<Primitive>(kTile));

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -28,72 +28,117 @@
namespace mindspore {
namespace ops {
namespace {
abstract::ShapePtr SliceInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
(void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual, 3, prim_name);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
auto shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape());
auto input_shape = shape_map[kShape];
auto min_shape = shape_map[kMinShape];
auto max_shape = shape_map[kMaxShape];
std::vector<std::vector<int64_t>> InferImplSliceFuncCalInputValue(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
std::vector<int64_t> tmp_input;
std::vector<std::vector<int64_t>> input_values;
// get begin and size value
for (size_t i = 1; i <= 2; ++i) {
std::vector<int64_t> tmp_input;
for (size_t i = 1; i <= kInputIndex2; ++i) {
auto input_value = input_args[i]->BuildValue();
MS_EXCEPTION_IF_NULL(input_value);
if (input_value->isa<tensor::Tensor>()) {
tmp_input = CheckAndConvertUtils::CheckTensorIntValue("slice args value", input_value, prim_name);
tmp_input = CheckAndConvertUtils::CheckTensorIntValue("slice args value", input_value, primitive->name());
} else if (input_value->isa<ValueTuple>()) {
tmp_input = CheckAndConvertUtils::CheckTupleInt("slice args value", input_value, primitive->name());
} else if (input_value->isa<ValueList>()) {
tmp_input = CheckAndConvertUtils::CheckListInt("slice args value", input_value, primitive->name());
} else {
tmp_input = CheckAndConvertUtils::CheckTupleInt("slice args value", input_value, prim_name);
MS_EXCEPTION(TypeError) << "For Slice, the begin and size must be Tuple or List.";
}
(void)input_values.emplace_back(tmp_input);
input_values.emplace_back(tmp_input);
}
auto begin_v = input_values[0];
auto size_v = input_values[1];
auto rank = input_shape.size();
if (begin_v.size() != rank || size_v.size() != rank) {
MS_LOG(EXCEPTION) << "For '" << prim_name
<< "', the shape of 'input', 'begin' and 'size' must be equal, but got 'input' shape: " << rank
<< ", 'begin' shape: " << begin_v.size() << ", 'size' shape: " << size_v.size() << ".";
}
for (size_t i = 0; i < size_v.size(); ++i) {
if (size_v[i] == -1) {
size_v[i] = input_shape[i] - begin_v[i];
}
}
if (max_shape.empty() && min_shape.empty()) {
return std::make_shared<abstract::Shape>(size_v);
}
return std::make_shared<abstract::Shape>(size_v, min_shape, max_shape);
return input_values;
}
TypePtr SliceInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(prim);
auto prim_name = prim->name();
(void)CheckAndConvertUtils::CheckInteger("slice_prim_infer", input_args.size(), kEqual, 3, prim_name);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
abstract::ShapePtr SliceInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
auto input_x_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape());
auto input_size_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape());
auto input_x_shape = input_x_shape_map[kShape];
auto input_x_shape_min = input_x_shape_map[kMinShape];
auto input_x_shape_max = input_x_shape_map[kMaxShape];
auto input_begin_value_ptr = input_args[kInputIndex1]->BuildValue();
auto input_size_value_ptr = input_args[kInputIndex2]->BuildValue();
auto input_size_shape = input_size_shape_map[kShape];
(void)CheckAndConvertUtils::CheckInteger("rank of input_x", SizeToLong(input_x_shape.size()), kGreaterThan, 0,
prim_name);
ShapeVector out_shape = {};
ShapeVector out_shape_min;
ShapeVector out_shape_max;
if (input_x_shape[0] == 0) {
MS_EXCEPTION(ValueError) << "For Slice, the input_x must hava value.";
}
MS_EXCEPTION_IF_NULL(input_args[0]);
auto x_type_map = input_args[0]->BuildType();
MS_EXCEPTION_IF_NULL(x_type_map);
auto x_dtype = x_type_map->cast<TensorTypePtr>();
MS_EXCEPTION_IF_NULL(x_dtype);
std::set<TypePtr> template_types = {kTensorType};
return CheckAndConvertUtils::CheckTensorTypeValid("x_dtype", x_dtype, template_types, prim_name);
if (!input_x_shape_max.empty()) {
out_shape_min = input_x_shape_min;
out_shape_max = input_x_shape_max;
} else {
out_shape_min = input_x_shape;
out_shape_max = input_x_shape;
}
if (input_begin_value_ptr->isa<AnyValue>() || input_size_value_ptr->isa<AnyValue>()) {
if (input_size_value_ptr->isa<AnyValue>()) {
if (input_size_shape[0] < 0) {
MS_EXCEPTION(ValueError) << "For Slice, the size shape haven't support dynamic yet.";
}
for (size_t i = 0; i < LongToSize(input_size_shape[0]); i++) {
out_shape.push_back(-1);
}
} else {
for (size_t i = 0; i < input_size_shape.size(); i++) {
out_shape.push_back(-1);
}
}
return std::make_shared<abstract::Shape>(out_shape, out_shape_min, out_shape_max);
}
auto input_values = InferImplSliceFuncCalInputValue(primitive, input_args);
auto input_begin_value = input_values[0];
auto input_size_value = input_values[1];
auto rank = input_x_shape.size();
if (input_begin_value.size() != rank || input_size_value.size() != rank) {
MS_EXCEPTION(ValueError) << "For Slice, the shape of input|begin|size must be equal.";
}
(void)CheckAndConvertUtils::CheckPositiveVector("input_begin", input_begin_value, prim_name);
bool is_dynamic = false;
for (size_t i = 0; i < rank; ++i) {
if (input_x_shape[i] < 0) {
is_dynamic = true;
continue;
}
if (input_begin_value[i] + input_size_value[i] > input_x_shape[i]) {
MS_EXCEPTION(ValueError) << "For Slice, the sum of begin_shape[" << i << "] and size_shape[" << i
<< "] must be no greater than input_x_shape[" << i << "].";
}
if (input_size_value[i] == -1) {
input_size_value[i] = input_x_shape[i] - input_begin_value[i];
}
out_shape_min[i] = input_size_value[i];
out_shape_max[i] = input_size_value[i];
}
if (!is_dynamic) {
return std::make_shared<abstract::Shape>(input_size_value);
} else {
return std::make_shared<abstract::Shape>(input_size_value, out_shape_min, out_shape_max);
}
}
TypePtr SliceInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
return CheckAndConvertUtils::CheckSubClass("input_x", input_args[0]->BuildType(), {kTensorType}, primitive->name());
}
} // namespace
MIND_API_OPERATOR_IMPL(Slice, BaseOperator);
AbstractBasePtr SliceInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
return abstract::MakeAbstract(SliceInferShape(primitive, input_args), SliceInferType(primitive, input_args));
MS_EXCEPTION_IF_NULL(primitive);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
auto prim_name = primitive->name();
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, kInputIndex3, prim_name);
auto type = SliceInferType(primitive, input_args);
auto shape = SliceInferShape(primitive, input_args);
return abstract::MakeAbstract(shape, type);
}
std::vector<int64_t> Slice::get_begin() const {
@ -106,8 +151,7 @@ std::vector<int64_t> Slice::get_size() const {
return GetValue<std::vector<int64_t>>(value_ptr);
}
REGISTER_PRIMITIVE_C(kNameSlice, Slice);
REGISTER_HOST_DEPENDS(kNameSlice, (std::set<int64_t>{1, 2}));
REGISTER_PRIMITIVE_EVAL_IMPL(Slice, prim::kPrimSlice, SliceInfer, nullptr, true);
} // namespace ops
} // namespace mindspore

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -816,6 +816,27 @@ std::vector<int64_t> CheckAndConvertUtils::CheckTupleInt(const std::string &arg_
return result;
}
std::vector<int64_t> CheckAndConvertUtils::CheckListInt(const std::string &arg_name, const ValuePtr &attr,
const std::string &prim_name) {
std::vector<int64_t> result;
MS_EXCEPTION_IF_NULL(attr);
if (attr->isa<ValueList>()) {
std::vector<ValuePtr> attr_vec = attr->cast<ValueListPtr>()->value();
(void)std::transform(
attr_vec.begin(), attr_vec.end(), std::back_inserter(result), [=](const ValuePtr &e) -> int64_t {
if (!e->isa<Int64Imm>()) {
MS_EXCEPTION(TypeError) << "For primitive[" << prim_name << "], the " << arg_name
<< " must be a list with all Int elements, but got " << attr->ToString();
}
return GetValue<int64_t>(e);
});
} else {
MS_EXCEPTION(TypeError) << "For primitive[" << prim_name << "], the " << arg_name
<< " must be a list with all Int elements, but got " << attr->ToString() << ".";
}
return result;
}
void CheckAndConvertUtils::CheckMinMaxShape(const ShapeVector &shape, ShapeVector *min_shape, ShapeVector *max_shape) {
*min_shape = (*min_shape).empty() ? shape : *min_shape;
*max_shape = (*max_shape).empty() ? shape : *max_shape;

View File

@ -303,6 +303,8 @@ class MS_CORE_API CheckAndConvertUtils {
const std::string &arg_name);
static std::vector<int64_t> CheckTupleInt(const std::string &prim_name, const ValuePtr &attr,
const std::string &arg_name);
static std::vector<int64_t> CheckListInt(const std::string &prim_name, const ValuePtr &attr,
const std::string &arg_name);
static void CheckMinMaxShape(const ShapeVector &shape, ShapeVector *min_shape, ShapeVector *max_shape);
static int64_t GetAndCheckFormat(const ValuePtr &value);
static size_t GetRemoveMonadAbsNum(const AbstractBasePtrList &abs_list);

View File

@ -0,0 +1,59 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Slice op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
slice_op_info = AiCPURegOp("Slice") \
.fusion_type("OPAQUE") \
.input(0, "x", "required") \
.input(1, "offsets", "required") \
.input(2, "size", "required") \
.output(0, "y", "required") \
.dtype_format(DataType.BOOL_Default, DataType.I32_Default, DataType.I32_Default, DataType.BOOL_Default) \
.dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.I32_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.I32_Default, DataType.F32_Default) \
.dtype_format(DataType.F64_Default, DataType.I32_Default, DataType.I32_Default, DataType.F64_Default) \
.dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.I32_Default, DataType.U8_Default) \
.dtype_format(DataType.U16_Default, DataType.I32_Default, DataType.I32_Default, DataType.U16_Default) \
.dtype_format(DataType.U32_Default, DataType.I32_Default, DataType.I32_Default, DataType.U32_Default) \
.dtype_format(DataType.U64_Default, DataType.I32_Default, DataType.I32_Default, DataType.U64_Default) \
.dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I32_Default, DataType.I8_Default) \
.dtype_format(DataType.I16_Default, DataType.I32_Default, DataType.I32_Default, DataType.I16_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I64_Default, DataType.I32_Default, DataType.I32_Default, DataType.I64_Default) \
.dtype_format(DataType.C64_Default, DataType.I32_Default, DataType.I32_Default, DataType.C64_Default) \
.dtype_format(DataType.C128_Default, DataType.I32_Default, DataType.I32_Default, DataType.C128_Default) \
.dtype_format(DataType.BOOL_Default, DataType.I64_Default, DataType.I64_Default, DataType.BOOL_Default) \
.dtype_format(DataType.F16_Default, DataType.I64_Default, DataType.I64_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.I64_Default, DataType.F32_Default) \
.dtype_format(DataType.F64_Default, DataType.I64_Default, DataType.I64_Default, DataType.F64_Default) \
.dtype_format(DataType.U8_Default, DataType.I64_Default, DataType.I64_Default, DataType.U8_Default) \
.dtype_format(DataType.U16_Default, DataType.I64_Default, DataType.I64_Default, DataType.U16_Default) \
.dtype_format(DataType.U32_Default, DataType.I64_Default, DataType.I64_Default, DataType.U32_Default) \
.dtype_format(DataType.U64_Default, DataType.I64_Default, DataType.I64_Default, DataType.U64_Default) \
.dtype_format(DataType.I8_Default, DataType.I64_Default, DataType.I64_Default, DataType.I8_Default) \
.dtype_format(DataType.I16_Default, DataType.I64_Default, DataType.I64_Default, DataType.I16_Default) \
.dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.I64_Default, DataType.I32_Default) \
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \
.dtype_format(DataType.C64_Default, DataType.I64_Default, DataType.I64_Default, DataType.C64_Default) \
.dtype_format(DataType.C128_Default, DataType.I64_Default, DataType.I64_Default, DataType.C128_Default) \
.get_op_info()
@op_info_register(slice_op_info)
def _slice_aicpu():
"""Slice AiCPU register"""
return

View File

@ -3123,11 +3123,36 @@ class Unstack(Primitive):
validator.check_value_type("axis", axis, [int], self.name)
class Slice(PrimitiveWithInfer):
class Slice(Primitive):
"""
Slices a tensor in the specified shape.
Refer to :func:`mindspore.ops.slice` for more detail.
Slice the tensor `input_x` in shape of `size` and starting at the location specified by `begin`,
The slice `begin` represents the offset in each dimension of `input_x`,
The slice `size` represents the size of the output tensor.
Note that `begin` is zero-based and `size` is one-based.
If `size[i]` is -1, all remaining elements in dimension i are included in the slice.
This is equivalent to setting :math:`size[i] = input_x.shape(i) - begin[i]`.
Inputs:
- **input_x** (Tensor): The target tensor.
The shape is :math:`(N,*)` where :math:`*` means, any number of additional dimensions.
- **begin** (Union[tuple, list]): The beginning of the slice. Only constant value(>=0) is allowed.
- **size** (Union[tuple, list]): The size of the slice. Only constant value is allowed.
Outputs:
Tensor, the shape is : input `size`, the data type is the same as `input_x`.
.. warning::
This is an experimental prototype that is subject to change and/or deletion.
Raises:
TypeError: If `input_x` is not a Tensor.
TypeError: If `begin` or `size` is neither tuple nor list.
ValueError: If `input_x`, `begin` and `size` not be equal in size.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
@ -3159,64 +3184,6 @@ class Slice(PrimitiveWithInfer):
"""Initialize slice"""
self.init_prim_io_names(inputs=['x', 'begin', 'size'], outputs=['output'])
def __infer__(self, x, begin, size):
x_shape = x['shape']
x_shp_len = len(x_shape)
begin_v, size_v = begin['value'], size['value']
if 'max_shape' in x and 'min_shape' in x:
max_shape = x['max_shape']
min_shape = x['min_shape']
else:
min_shape = x['shape']
max_shape = x['shape']
if begin_v is None or size_v is None:
# if size_v is not None and begin_v is None, it should be also a dynamic output shape.
if size_v is None:
if size['shape'][0] < 0:
raise ValueError(f"For '{self.name}', the size shape haven't support dynamic yet.")
out_shape = [-1] * size['shape'][0]
else:
out_shape = [-1] * len(size_v)
return {'shape': out_shape,
'dtype': x['dtype'],
'value': None,
'min_shape': min_shape,
'max_shape': max_shape}
validator.check_valid_input('begin', begin['value'], self.name)
validator.check_valid_input('size', size['value'], self.name)
validator.check_value_type("input begin", begin_v, [tuple, list], self.name)
validator.check_value_type("input size", size_v, [tuple, list], self.name)
for key, value in zip(('begin', 'size'), (begin_v, size_v)):
validator.check(f'len of {key}', len(value),
'len x\'s dim', x_shp_len)
size_v = list(size_v)
is_dynamic = False
for i in range(x_shp_len):
validator.check_non_negative_int(begin_v[i], f'input begin[{i}]')
if x_shape[i] == -1:
is_dynamic = True
continue
if size_v[i] == -1:
size_v[i] = x_shape[i] - begin_v[i]
validator.check_positive_int(size_v[i], f'input size[{i}]')
if x_shape[i] < begin_v[i] + size_v[i]:
y = begin_v[i] + size_v[i]
raise ValueError(f"For '{self.name}', the sliced shape can not be greater than origin shape, "
f"but got sliced shape is {y}, and origin shape is {x_shape}.")
if not is_dynamic:
return {'shape': size_v,
'dtype': x['dtype'],
'value': None}
if size_v[i] >= 0:
min_shape[i] = size_v[i]
max_shape[i] = size_v[i]
return {'shape': size_v,
'dtype': x['dtype'],
'value': None,
'min_shape': min_shape,
'max_shape': max_shape}
class Coalesce(Primitive):
"""