!31375 [assistant][ops] Add Slice
Merge pull request !31375 from TR-nbu/Slice
This commit is contained in: commit 645d9ba6fb
@@ -33,7 +33,8 @@ std::map<string, std::vector<std::pair<string, size_t>>> AicpuOpAttrToInputMap =
     {prim::kPrimOneHot->name(), {{"depth", 1}}},
     {prim::kPrimConcat->name(), {{"axis", 0}}},
     {prim::kPrimTranspose->name(), {{"perm", 1}}},
-    {prim::kPrimGather->name(), {{"axis", 2}}}};
+    {prim::kPrimGather->name(), {{"axis", 2}}},
+    {prim::kPrimSlice->name(), {{"begin", 1}, {"size", 2}}}};

 bool GetAicpuOpAttrToInputInfo(const CNodePtr &kernel_node, std::vector<std::pair<string, size_t>> *info) {
   std::string op_name = common::AnfAlgo::GetCNodeName(kernel_node);
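For context, the consumer of this table is the GetAicpuOpAttrToInputInfo function shown in the hunk. A minimal standalone sketch of that lookup pattern (types simplified; LookupAttrToInput is a hypothetical name, not the MindSpore API):

// Hypothetical illustration of how a table like AicpuOpAttrToInputMap can
// drive attribute-to-input conversion for AICPU kernels.
#include <map>
#include <string>
#include <utility>
#include <vector>

using AttrToInputMap = std::map<std::string, std::vector<std::pair<std::string, size_t>>>;

bool LookupAttrToInput(const AttrToInputMap &table, const std::string &op_name,
                       std::vector<std::pair<std::string, size_t>> *info) {
  auto it = table.find(op_name);
  if (it == table.end()) {
    return false;  // op keeps its attributes; nothing to convert
  }
  *info = it->second;  // e.g. Slice: "begin" -> input 1, "size" -> input 2
  return true;
}

With the new entry, an AICPU Slice node has its "begin" attribute materialized as input 1 and "size" as input 2, matching the AiCPU operator signature registered later in this PR.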
@@ -15,6 +15,7 @@
  */

 #include "plugin/device/cpu/kernel/slice_cpu_kernel.h"
+#include <complex>
 #include <algorithm>
 #include <unordered_map>
 #include "include/common/thread_pool.h"
@@ -42,9 +43,19 @@ void SliceCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
   kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
   cnode_ptr_ = kernel_node;
   static const std::unordered_map<TypeId, int> type_size_map = {{kNumberTypeBool, sizeof(bool)},
-                                                                {kNumberTypeInt32, sizeof(int)},
+                                                                {kNumberTypeInt8, sizeof(int8_t)},
+                                                                {kNumberTypeInt16, sizeof(int16_t)},
+                                                                {kNumberTypeInt32, sizeof(int32_t)},
+                                                                {kNumberTypeInt64, sizeof(int64_t)},
+                                                                {kNumberTypeUInt8, sizeof(uint8_t)},
+                                                                {kNumberTypeUInt16, sizeof(uint16_t)},
+                                                                {kNumberTypeUInt32, sizeof(uint32_t)},
+                                                                {kNumberTypeUInt64, sizeof(uint64_t)},
                                                                 {kNumberTypeFloat32, sizeof(float)},
-                                                                {kNumberTypeFloat64, sizeof(double)}};
+                                                                {kNumberTypeFloat64, sizeof(double)},
+                                                                {kNumberTypeFloat16, sizeof(float16)},
+                                                                {kNumberTypeComplex64, sizeof(std::complex<float>)},
+                                                                {kNumberTypeComplex128, sizeof(std::complex<double>)}};
   auto input_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
   if (input_shape.size() > DIMENSION_8D || input_shape.empty()) {
     MS_LOG(EXCEPTION) << "For '" << kernel_name_
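Why a per-dtype element size rather than per-dtype dispatch: the CPU slice copy only moves bytes, so one code path can serve every entry in the map. A hedged sketch of that idea (hypothetical helper, not the kernel's actual code):

// bytes to move = number of contiguous elements * size of one element,
// so the same copy routine works for bool, int8, ..., complex128.
#include <cstddef>
#include <cstring>

void CopySliceRun(void *dst, const void *src, size_t elem_count, size_t data_size) {
  std::memcpy(dst, src, elem_count * data_size);
}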
@@ -70,9 +81,7 @@ void SliceCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
   TypeId dtype = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
   auto size_pair = type_size_map.find(dtype);
   if (size_pair == type_size_map.end()) {
-    MS_LOG(EXCEPTION) << "For '" << kernel_name_
-                      << "', the dtype of 'input_x' must be bool, int32, float32 or float64, but got "
-                      << TypeIdToType(dtype)->ToString();
+    MS_LOG(EXCEPTION) << "Slice supports type in type_size_map, but got " << TypeIdToType(dtype)->ToString();
   }
   data_size_ = size_pair->second;
 }
@@ -155,9 +164,9 @@ bool SliceCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs, co
   for (size_t i = 0; i < begin.size(); ++i) {
     if (input_shape[i] < LongToSize(begin[i] + size[i])) {
       MS_LOG(EXCEPTION) << "For '" << kernel_name_
-                        << "', slice shape can not be greater than origin shape. But in dimension i=" << i
-                        << ", origin shape 'input_shape[i]' is " << input_shape[i] << " and slice shape is "
-                        << LongToSize(begin[i] + size[i]);
+                        << "', slice shape must be no greater than origin shape. But in dimension i=" << i
+                        << ", origin shape 'input_shape[i]' is " << input_shape[i]
+                        << " and slice shape 'LongToSize(begin[i] + size[i])' is " << LongToSize(begin[i] + size[i]);
     }
   }
   InitSliceParam(input_shape, begin, size);
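The check above rejects out-of-range slices before any copying happens. A self-contained sketch of the same rule, assuming plain int64 vectors for shape/begin/size (not MindSpore code):

// For input_shape = {4, 6}, begin = {1, 2}, size = {3, 4} the check passes
// (1+3 <= 4 and 2+4 <= 6); size = {4, 4} would throw in dimension 0.
#include <cstdint>
#include <stdexcept>
#include <string>
#include <vector>

void CheckSliceBounds(const std::vector<int64_t> &input_shape, const std::vector<int64_t> &begin,
                      const std::vector<int64_t> &size) {
  for (size_t i = 0; i < begin.size(); ++i) {
    if (begin[i] + size[i] > input_shape[i]) {
      throw std::out_of_range("slice shape must be no greater than origin shape in dimension " + std::to_string(i));
    }
  }
}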
@@ -31,7 +31,6 @@ class SliceCpuKernelMod : public DeprecatedNativeCpuKernelMod {
   ~SliceCpuKernelMod() override = default;

   void InitKernel(const CNodePtr &kernel_node) override;
-
   bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
               const std::vector<AddressPtr> &outputs) override;

@@ -39,9 +38,19 @@ class SliceCpuKernelMod : public DeprecatedNativeCpuKernelMod {
   std::vector<KernelAttr> GetOpSupport() override {
     static std::vector<KernelAttr> support_list = {
       KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
+      KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
+      KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
       KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
+      KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
+      KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
+      KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
+      KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
+      KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
+      KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
       KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
       KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
+      KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64),
+      KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128),
       KernelAttr()
         .AddInputAttr(kNumberTypeFloat32)
         .AddInputAttr(kNumberTypeInt32)
@@ -69,5 +78,4 @@ class SliceCpuKernelMod : public DeprecatedNativeCpuKernelMod {
 };
 }  // namespace kernel
 }  // namespace mindspore
-
 #endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SLICE_CPU_KERNEL_H_
@@ -164,6 +164,7 @@ constexpr auto kMeshgrid = "Meshgrid";
 constexpr auto kScatterNdMax = "ScatterNdMax";
 constexpr auto kScatterNdMin = "ScatterNdMin";
 constexpr auto kCSRSparseMatrixToSparseTensor = "CSRSparseMatrixToSparseTensor";
+constexpr auto kSlice = "Slice";

 // NN
 constexpr auto kFractionalMaxPool3DWithFixedKsize = "FractionalMaxPool3DWithFixedKsize";
@@ -361,7 +362,7 @@ GVAR_DEF(PrimitivePtr, kPrimComputeAccidentalHits, std::make_shared<Primitive>("
 GVAR_DEF(PrimitivePtr, kPrimCacheSwapTable, std::make_shared<Primitive>("CacheSwapTable"));
 GVAR_DEF(PrimitivePtr, kPrimDynamicAssign, std::make_shared<Primitive>("DynamicAssign"));
 GVAR_DEF(PrimitivePtr, kPrimPadAndShift, std::make_shared<Primitive>("PadAndShift"));
-GVAR_DEF(PrimitivePtr, kPrimSlice, std::make_shared<Primitive>("Slice"));
+GVAR_DEF(PrimitivePtr, kPrimSlice, std::make_shared<Primitive>(kSlice));
 GVAR_DEF(PrimitivePtr, kPrimSliceGrad, std::make_shared<Primitive>("SliceGrad"));
 GVAR_DEF(PrimitivePtr, kPrimSliceFusion, std::make_shared<Primitive>("SliceFusion"));
 GVAR_DEF(PrimitivePtr, kPrimTile, std::make_shared<Primitive>(kTile));
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2022 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -28,72 +28,117 @@
 namespace mindspore {
 namespace ops {
 namespace {
-abstract::ShapePtr SliceInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
-  MS_EXCEPTION_IF_NULL(primitive);
-  auto prim_name = primitive->name();
-  (void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual, 3, prim_name);
-  for (const auto &item : input_args) {
-    MS_EXCEPTION_IF_NULL(item);
-  }
-  auto shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape());
-  auto input_shape = shape_map[kShape];
-  auto min_shape = shape_map[kMinShape];
-  auto max_shape = shape_map[kMaxShape];
-  std::vector<int64_t> tmp_input;
-  std::vector<std::vector<int64_t>> input_values;
-  // get begin and size value
-  for (size_t i = 1; i <= 2; ++i) {
-    auto input_value = input_args[i]->BuildValue();
-    MS_EXCEPTION_IF_NULL(input_value);
-    if (input_value->isa<tensor::Tensor>()) {
-      tmp_input = CheckAndConvertUtils::CheckTensorIntValue("slice args value", input_value, prim_name);
-    } else {
-      tmp_input = CheckAndConvertUtils::CheckTupleInt("slice args value", input_value, prim_name);
-    }
-    (void)input_values.emplace_back(tmp_input);
-  }
-  auto begin_v = input_values[0];
-  auto size_v = input_values[1];
-  auto rank = input_shape.size();
-  if (begin_v.size() != rank || size_v.size() != rank) {
-    MS_LOG(EXCEPTION) << "For '" << prim_name
-                      << "', the shape of 'input', 'begin' and 'size' must be equal, but got 'input' shape: " << rank
-                      << ", 'begin' shape: " << begin_v.size() << ", 'size' shape: " << size_v.size() << ".";
-  }
-  for (size_t i = 0; i < size_v.size(); ++i) {
-    if (size_v[i] == -1) {
-      size_v[i] = input_shape[i] - begin_v[i];
-    }
-  }
-  if (max_shape.empty() && min_shape.empty()) {
-    return std::make_shared<abstract::Shape>(size_v);
-  }
-  return std::make_shared<abstract::Shape>(size_v, min_shape, max_shape);
-}
-
-TypePtr SliceInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
-  MS_EXCEPTION_IF_NULL(prim);
-  auto prim_name = prim->name();
-  (void)CheckAndConvertUtils::CheckInteger("slice_prim_infer", input_args.size(), kEqual, 3, prim_name);
-  for (const auto &item : input_args) {
-    MS_EXCEPTION_IF_NULL(item);
-  }
-  MS_EXCEPTION_IF_NULL(input_args[0]);
-  auto x_type_map = input_args[0]->BuildType();
-  MS_EXCEPTION_IF_NULL(x_type_map);
-  auto x_dtype = x_type_map->cast<TensorTypePtr>();
-  MS_EXCEPTION_IF_NULL(x_dtype);
-  std::set<TypePtr> template_types = {kTensorType};
-  return CheckAndConvertUtils::CheckTensorTypeValid("x_dtype", x_dtype, template_types, prim_name);
-}
+std::vector<std::vector<int64_t>> InferImplSliceFuncCalInputValue(const PrimitivePtr &primitive,
+                                                                  const std::vector<AbstractBasePtr> &input_args) {
+  std::vector<std::vector<int64_t>> input_values;
+  // get begin and size value
+  for (size_t i = 1; i <= kInputIndex2; ++i) {
+    std::vector<int64_t> tmp_input;
+    auto input_value = input_args[i]->BuildValue();
+    MS_EXCEPTION_IF_NULL(input_value);
+    if (input_value->isa<tensor::Tensor>()) {
+      tmp_input = CheckAndConvertUtils::CheckTensorIntValue("slice args value", input_value, primitive->name());
+    } else if (input_value->isa<ValueTuple>()) {
+      tmp_input = CheckAndConvertUtils::CheckTupleInt("slice args value", input_value, primitive->name());
+    } else if (input_value->isa<ValueList>()) {
+      tmp_input = CheckAndConvertUtils::CheckListInt("slice args value", input_value, primitive->name());
+    } else {
+      MS_EXCEPTION(TypeError) << "For Slice, the begin and size must be Tuple or List.";
+    }
+    input_values.emplace_back(tmp_input);
+  }
+  return input_values;
+}
+
+abstract::ShapePtr SliceInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
+  MS_EXCEPTION_IF_NULL(primitive);
+  auto prim_name = primitive->name();
+  auto input_x_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape());
+  auto input_size_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape());
+  auto input_x_shape = input_x_shape_map[kShape];
+  auto input_x_shape_min = input_x_shape_map[kMinShape];
+  auto input_x_shape_max = input_x_shape_map[kMaxShape];
+  auto input_begin_value_ptr = input_args[kInputIndex1]->BuildValue();
+  auto input_size_value_ptr = input_args[kInputIndex2]->BuildValue();
+  auto input_size_shape = input_size_shape_map[kShape];
+  (void)CheckAndConvertUtils::CheckInteger("rank of input_x", SizeToLong(input_x_shape.size()), kGreaterThan, 0,
+                                           prim_name);
+  ShapeVector out_shape = {};
+  ShapeVector out_shape_min;
+  ShapeVector out_shape_max;
+  if (input_x_shape[0] == 0) {
+    MS_EXCEPTION(ValueError) << "For Slice, the input_x must have value.";
+  }
+  if (!input_x_shape_max.empty()) {
+    out_shape_min = input_x_shape_min;
+    out_shape_max = input_x_shape_max;
+  } else {
+    out_shape_min = input_x_shape;
+    out_shape_max = input_x_shape;
+  }
+  if (input_begin_value_ptr->isa<AnyValue>() || input_size_value_ptr->isa<AnyValue>()) {
+    if (input_size_value_ptr->isa<AnyValue>()) {
+      if (input_size_shape[0] < 0) {
+        MS_EXCEPTION(ValueError) << "For Slice, dynamic 'size' input is not supported yet.";
+      }
+      for (size_t i = 0; i < LongToSize(input_size_shape[0]); i++) {
+        out_shape.push_back(-1);
+      }
+    } else {
+      for (size_t i = 0; i < input_size_shape.size(); i++) {
+        out_shape.push_back(-1);
+      }
+    }
+    return std::make_shared<abstract::Shape>(out_shape, out_shape_min, out_shape_max);
+  }
+  auto input_values = InferImplSliceFuncCalInputValue(primitive, input_args);
+  auto input_begin_value = input_values[0];
+  auto input_size_value = input_values[1];
+  auto rank = input_x_shape.size();
+  if (input_begin_value.size() != rank || input_size_value.size() != rank) {
+    MS_EXCEPTION(ValueError) << "For Slice, the shape of input|begin|size must be equal.";
+  }
+  (void)CheckAndConvertUtils::CheckPositiveVector("input_begin", input_begin_value, prim_name);
+  bool is_dynamic = false;
+  for (size_t i = 0; i < rank; ++i) {
+    if (input_x_shape[i] < 0) {
+      is_dynamic = true;
+      continue;
+    }
+    if (input_begin_value[i] + input_size_value[i] > input_x_shape[i]) {
+      MS_EXCEPTION(ValueError) << "For Slice, the sum of begin_shape[" << i << "] and size_shape[" << i
+                               << "] must be no greater than input_x_shape[" << i << "].";
+    }
+    if (input_size_value[i] == -1) {
+      input_size_value[i] = input_x_shape[i] - input_begin_value[i];
+    }
+    out_shape_min[i] = input_size_value[i];
+    out_shape_max[i] = input_size_value[i];
+  }
+  if (!is_dynamic) {
+    return std::make_shared<abstract::Shape>(input_size_value);
+  } else {
+    return std::make_shared<abstract::Shape>(input_size_value, out_shape_min, out_shape_max);
+  }
+}
+
+TypePtr SliceInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
+  MS_EXCEPTION_IF_NULL(primitive);
+  return CheckAndConvertUtils::CheckSubClass("input_x", input_args[0]->BuildType(), {kTensorType}, primitive->name());
+}
 }  // namespace

 MIND_API_OPERATOR_IMPL(Slice, BaseOperator);
 AbstractBasePtr SliceInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                            const std::vector<AbstractBasePtr> &input_args) {
-  return abstract::MakeAbstract(SliceInferShape(primitive, input_args), SliceInferType(primitive, input_args));
+  MS_EXCEPTION_IF_NULL(primitive);
+  for (const auto &item : input_args) {
+    MS_EXCEPTION_IF_NULL(item);
+  }
+  auto prim_name = primitive->name();
+  CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, kInputIndex3, prim_name);
+  auto type = SliceInferType(primitive, input_args);
+  auto shape = SliceInferShape(primitive, input_args);
+  return abstract::MakeAbstract(shape, type);
 }

 std::vector<int64_t> Slice::get_begin() const {
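The key rule in the new SliceInferShape is the -1 convention: a size[i] of -1 expands to "everything from begin[i] to the end of dimension i". A minimal standalone sketch with a worked example (hypothetical helper name, not the MindSpore API):

#include <cstdint>
#include <vector>

// Resolve -1 entries in `size`; the output shape equals the resolved size vector.
std::vector<int64_t> InferSliceOutShape(const std::vector<int64_t> &x_shape, const std::vector<int64_t> &begin,
                                        std::vector<int64_t> size) {
  for (size_t i = 0; i < size.size(); ++i) {
    if (size[i] == -1) {
      size[i] = x_shape[i] - begin[i];  // e.g. x_shape={3,4}, begin={1,0}, size={-1,4} -> {2,4}
    }
  }
  return size;
}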
@@ -106,8 +151,7 @@ std::vector<int64_t> Slice::get_size() const {
   return GetValue<std::vector<int64_t>>(value_ptr);
 }

-REGISTER_PRIMITIVE_C(kNameSlice, Slice);
-
+REGISTER_HOST_DEPENDS(kNameSlice, (std::set<int64_t>{1, 2}));
+REGISTER_PRIMITIVE_EVAL_IMPL(Slice, prim::kPrimSlice, SliceInfer, nullptr, true);
 }  // namespace ops
 }  // namespace mindspore
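REGISTER_HOST_DEPENDS(kNameSlice, (std::set<int64_t>{1, 2})) marks inputs 1 ("begin") and 2 ("size") as host-dependent, i.e. their values must be readable on the host so SliceInfer can compute the output shape. A speculative sketch of what such a registry could look like (this is not the real macro expansion):

#include <cstdint>
#include <map>
#include <set>
#include <string>

// Global table: op name -> input indices whose *values* shape inference needs.
std::map<std::string, std::set<int64_t>> &HostDependsMap() {
  static std::map<std::string, std::set<int64_t>> host_depends;
  return host_depends;
}

struct HostDependsRegistrar {
  HostDependsRegistrar(const std::string &op, const std::set<int64_t> &indices) {
    HostDependsMap()[op] = indices;  // e.g. {"Slice", {1, 2}}
  }
};

static HostDependsRegistrar g_slice_host_depends("Slice", {1, 2});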
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2022 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -816,6 +816,27 @@ std::vector<int64_t> CheckAndConvertUtils::CheckTupleInt(const std::string &arg_
   return result;
 }

+std::vector<int64_t> CheckAndConvertUtils::CheckListInt(const std::string &arg_name, const ValuePtr &attr,
+                                                        const std::string &prim_name) {
+  std::vector<int64_t> result;
+  MS_EXCEPTION_IF_NULL(attr);
+  if (attr->isa<ValueList>()) {
+    std::vector<ValuePtr> attr_vec = attr->cast<ValueListPtr>()->value();
+    (void)std::transform(
+      attr_vec.begin(), attr_vec.end(), std::back_inserter(result), [=](const ValuePtr &e) -> int64_t {
+        if (!e->isa<Int64Imm>()) {
+          MS_EXCEPTION(TypeError) << "For primitive[" << prim_name << "], the " << arg_name
+                                  << " must be a list with all Int elements, but got " << attr->ToString();
+        }
+        return GetValue<int64_t>(e);
+      });
+  } else {
+    MS_EXCEPTION(TypeError) << "For primitive[" << prim_name << "], the " << arg_name
+                            << " must be a list with all Int elements, but got " << attr->ToString() << ".";
+  }
+  return result;
+}
+
 void CheckAndConvertUtils::CheckMinMaxShape(const ShapeVector &shape, ShapeVector *min_shape, ShapeVector *max_shape) {
   *min_shape = (*min_shape).empty() ? shape : *min_shape;
   *max_shape = (*max_shape).empty() ? shape : *max_shape;
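The new CheckListInt follows the same validate-and-convert pattern as the existing CheckTupleInt: transform every element, failing fast on the first non-integer. A standalone analogue, with std::variant standing in for MindSpore's ValuePtr (illustrative only):

#include <algorithm>
#include <cstdint>
#include <iterator>
#include <stdexcept>
#include <string>
#include <variant>
#include <vector>

using Value = std::variant<int64_t, std::string>;

// Convert a list of Values to int64, throwing if any element is not an Int.
std::vector<int64_t> CheckListIntSketch(const std::string &arg_name, const std::vector<Value> &list) {
  std::vector<int64_t> result;
  (void)std::transform(list.begin(), list.end(), std::back_inserter(result), [&](const Value &e) -> int64_t {
    if (!std::holds_alternative<int64_t>(e)) {
      throw std::invalid_argument(arg_name + " must be a list with all Int elements");
    }
    return std::get<int64_t>(e);
  });
  return result;
}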
@@ -303,6 +303,8 @@ class MS_CORE_API CheckAndConvertUtils {
                                             const std::string &arg_name);
   static std::vector<int64_t> CheckTupleInt(const std::string &prim_name, const ValuePtr &attr,
                                             const std::string &arg_name);
+  static std::vector<int64_t> CheckListInt(const std::string &prim_name, const ValuePtr &attr,
+                                           const std::string &arg_name);
   static void CheckMinMaxShape(const ShapeVector &shape, ShapeVector *min_shape, ShapeVector *max_shape);
   static int64_t GetAndCheckFormat(const ValuePtr &value);
   static size_t GetRemoveMonadAbsNum(const AbstractBasePtrList &abs_list);
@@ -0,0 +1,59 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Slice op"""
+from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
+
+slice_op_info = AiCPURegOp("Slice") \
+    .fusion_type("OPAQUE") \
+    .input(0, "x", "required") \
+    .input(1, "offsets", "required") \
+    .input(2, "size", "required") \
+    .output(0, "y", "required") \
+    .dtype_format(DataType.BOOL_Default, DataType.I32_Default, DataType.I32_Default, DataType.BOOL_Default) \
+    .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.I32_Default, DataType.F16_Default) \
+    .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.I32_Default, DataType.F32_Default) \
+    .dtype_format(DataType.F64_Default, DataType.I32_Default, DataType.I32_Default, DataType.F64_Default) \
+    .dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.I32_Default, DataType.U8_Default) \
+    .dtype_format(DataType.U16_Default, DataType.I32_Default, DataType.I32_Default, DataType.U16_Default) \
+    .dtype_format(DataType.U32_Default, DataType.I32_Default, DataType.I32_Default, DataType.U32_Default) \
+    .dtype_format(DataType.U64_Default, DataType.I32_Default, DataType.I32_Default, DataType.U64_Default) \
+    .dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I32_Default, DataType.I8_Default) \
+    .dtype_format(DataType.I16_Default, DataType.I32_Default, DataType.I32_Default, DataType.I16_Default) \
+    .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
+    .dtype_format(DataType.I64_Default, DataType.I32_Default, DataType.I32_Default, DataType.I64_Default) \
+    .dtype_format(DataType.C64_Default, DataType.I32_Default, DataType.I32_Default, DataType.C64_Default) \
+    .dtype_format(DataType.C128_Default, DataType.I32_Default, DataType.I32_Default, DataType.C128_Default) \
+    .dtype_format(DataType.BOOL_Default, DataType.I64_Default, DataType.I64_Default, DataType.BOOL_Default) \
+    .dtype_format(DataType.F16_Default, DataType.I64_Default, DataType.I64_Default, DataType.F16_Default) \
+    .dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.I64_Default, DataType.F32_Default) \
+    .dtype_format(DataType.F64_Default, DataType.I64_Default, DataType.I64_Default, DataType.F64_Default) \
+    .dtype_format(DataType.U8_Default, DataType.I64_Default, DataType.I64_Default, DataType.U8_Default) \
+    .dtype_format(DataType.U16_Default, DataType.I64_Default, DataType.I64_Default, DataType.U16_Default) \
+    .dtype_format(DataType.U32_Default, DataType.I64_Default, DataType.I64_Default, DataType.U32_Default) \
+    .dtype_format(DataType.U64_Default, DataType.I64_Default, DataType.I64_Default, DataType.U64_Default) \
+    .dtype_format(DataType.I8_Default, DataType.I64_Default, DataType.I64_Default, DataType.I8_Default) \
+    .dtype_format(DataType.I16_Default, DataType.I64_Default, DataType.I64_Default, DataType.I16_Default) \
+    .dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.I64_Default, DataType.I32_Default) \
+    .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \
+    .dtype_format(DataType.C64_Default, DataType.I64_Default, DataType.I64_Default, DataType.C64_Default) \
+    .dtype_format(DataType.C128_Default, DataType.I64_Default, DataType.I64_Default, DataType.C128_Default) \
+    .get_op_info()
+
+
+@op_info_register(slice_op_info)
+def _slice_aicpu():
+    """Slice AiCPU register"""
+    return
@@ -3123,11 +3123,36 @@ class Unstack(Primitive):
         validator.check_value_type("axis", axis, [int], self.name)


-class Slice(PrimitiveWithInfer):
+class Slice(Primitive):
     """
     Slices a tensor in the specified shape.

-    Refer to :func:`mindspore.ops.slice` for more detail.
+    Slices the tensor `input_x` into the shape given by `size`, starting at the location specified by `begin`.
+    The slice `begin` represents the offset in each dimension of `input_x`.
+    The slice `size` represents the size of the output tensor.
+
+    Note that `begin` is zero-based and `size` is one-based.
+
+    If `size[i]` is -1, all remaining elements in dimension i are included in the slice.
+    This is equivalent to setting :math:`size[i] = input_x.shape(i) - begin[i]`.
+
+    Inputs:
+        - **input_x** (Tensor): The target tensor.
+          The shape is :math:`(N,*)` where :math:`*` means any number of additional dimensions.
+        - **begin** (Union[tuple, list]): The beginning of the slice. Only constant value(>=0) is allowed.
+        - **size** (Union[tuple, list]): The size of the slice. Only constant value is allowed.
+
+    Outputs:
+        Tensor, the shape is input `size`, and the data type is the same as `input_x`.
+
+    .. warning::
+        This is an experimental prototype that is subject to change and/or deletion.
+
+    Raises:
+        TypeError: If `input_x` is not a Tensor.
+        TypeError: If `begin` or `size` is neither tuple nor list.
+        ValueError: If `input_x`, `begin` and `size` are not equal in size.

     Supported Platforms:
         ``Ascend`` ``GPU`` ``CPU``
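For readers new to the op: Slice copies the block that starts at `begin` and extends `size` elements in each dimension. A minimal row-major 2-D sketch of that semantics (plain C++, independent of MindSpore):

#include <cstdint>
#include <vector>

// Extract the size[0] x size[1] block starting at begin from a rows x cols array.
std::vector<float> Slice2D(const std::vector<float> &x, int64_t cols,
                           const std::vector<int64_t> &begin, const std::vector<int64_t> &size) {
  std::vector<float> out;
  out.reserve(static_cast<size_t>(size[0] * size[1]));
  for (int64_t r = begin[0]; r < begin[0] + size[0]; ++r) {
    for (int64_t c = begin[1]; c < begin[1] + size[1]; ++c) {
      out.push_back(x[static_cast<size_t>(r * cols + c)]);
    }
  }
  return out;  // e.g. cols=4, begin={1,0}, size={2,3} -> a 2x3 block
}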
@@ -3159,64 +3184,6 @@ class Slice(PrimitiveWithInfer):
         """Initialize slice"""
         self.init_prim_io_names(inputs=['x', 'begin', 'size'], outputs=['output'])

-    def __infer__(self, x, begin, size):
-        x_shape = x['shape']
-        x_shp_len = len(x_shape)
-        begin_v, size_v = begin['value'], size['value']
-
-        if 'max_shape' in x and 'min_shape' in x:
-            max_shape = x['max_shape']
-            min_shape = x['min_shape']
-        else:
-            min_shape = x['shape']
-            max_shape = x['shape']
-        if begin_v is None or size_v is None:
-            # if size_v is not None and begin_v is None, it should be also a dynamic output shape.
-            if size_v is None:
-                if size['shape'][0] < 0:
-                    raise ValueError(f"For '{self.name}', the size shape haven't support dynamic yet.")
-                out_shape = [-1] * size['shape'][0]
-            else:
-                out_shape = [-1] * len(size_v)
-            return {'shape': out_shape,
-                    'dtype': x['dtype'],
-                    'value': None,
-                    'min_shape': min_shape,
-                    'max_shape': max_shape}
-        validator.check_valid_input('begin', begin['value'], self.name)
-        validator.check_valid_input('size', size['value'], self.name)
-        validator.check_value_type("input begin", begin_v, [tuple, list], self.name)
-        validator.check_value_type("input size", size_v, [tuple, list], self.name)
-        for key, value in zip(('begin', 'size'), (begin_v, size_v)):
-            validator.check(f'len of {key}', len(value),
-                            'len x\'s dim', x_shp_len)
-        size_v = list(size_v)
-        is_dynamic = False
-        for i in range(x_shp_len):
-            validator.check_non_negative_int(begin_v[i], f'input begin[{i}]')
-            if x_shape[i] == -1:
-                is_dynamic = True
-                continue
-            if size_v[i] == -1:
-                size_v[i] = x_shape[i] - begin_v[i]
-            validator.check_positive_int(size_v[i], f'input size[{i}]')
-            if x_shape[i] < begin_v[i] + size_v[i]:
-                y = begin_v[i] + size_v[i]
-                raise ValueError(f"For '{self.name}', the sliced shape can not be greater than origin shape, "
-                                 f"but got sliced shape is {y}, and origin shape is {x_shape}.")
-            if size_v[i] >= 0:
-                min_shape[i] = size_v[i]
-                max_shape[i] = size_v[i]
-        if not is_dynamic:
-            return {'shape': size_v,
-                    'dtype': x['dtype'],
-                    'value': None}
-        return {'shape': size_v,
-                'dtype': x['dtype'],
-                'value': None,
-                'min_shape': min_shape,
-                'max_shape': max_shape}
-
-
 class Coalesce(Primitive):
     """