!49555 ops StridedSliceV2, StridedSliceV2Grad supports dynamic shape feature

Merge pull request !49555 from wang_ziqi/br_stridedslice_v2
This commit is contained in:
i-robot 2023-03-01 06:36:00 +00:00 committed by Gitee
commit 120f7b597e
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
4 changed files with 174 additions and 143 deletions

View File

@ -259,21 +259,24 @@ std::vector<bool> GradDec2Bin(const int64_t &mask) {
}
template <typename T>
void ParseStrideSliceGradMasksST(const CNodePtr &kernel_node, std::vector<T> *begin, std::vector<T> *end,
void ParseStrideSliceGradMasksST(const BaseOperatorPtr &base_operator, std::vector<T> *begin, std::vector<T> *end,
std::vector<T> *stride, ShapeVector *input_shape, const ShapeVector output_shape,
int shape_dim_output, int slice_len) {
std::vector<T> &_begin_attr = *begin;
std::vector<T> &_end_attr = *end;
std::vector<T> &_stride_attr = *stride;
auto begin_mask_int = common::AnfAlgo::GetNodeAttr<int64_t>(kernel_node, kAttrBeginMask);
auto prim = base_operator->GetPrim();
MS_EXCEPTION_IF_NULL(prim);
auto begin_mask_int = GetValue<int64_t>(prim->GetAttr(kAttrBeginMask));
auto begin_mask = GradDec2Bin(begin_mask_int);
auto end_mask_int = common::AnfAlgo::GetNodeAttr<int64_t>(kernel_node, kAttrEndMask);
auto end_mask_int = GetValue<int64_t>(prim->GetAttr(kAttrEndMask));
auto end_mask = GradDec2Bin(end_mask_int);
auto ellipsis_mask_int = common::AnfAlgo::GetNodeAttr<int64_t>(kernel_node, kAttrEllipsisMask);
auto ellipsis_mask_int = GetValue<int64_t>(prim->GetAttr(kAttrEllipsisMask));
auto ellipsis_mask = GradDec2Bin(ellipsis_mask_int);
auto new_axis_mask_int = common::AnfAlgo::GetNodeAttr<int64_t>(kernel_node, kAttrNewAxisMask);
auto new_axis_mask_int = GetValue<int64_t>(prim->GetAttr(kAttrNewAxisMask));
auto new_axis_mask = GradDec2Bin(new_axis_mask_int);
auto shrink_axis_mask_int = common::AnfAlgo::GetNodeAttr<int64_t>(kernel_node, kAttrShrinkAxisMask);
auto shrink_axis_mask_int = GetValue<int64_t>(prim->GetAttr(kAttrShrinkAxisMask));
auto shrink_axis_mask = GradDec2Bin(shrink_axis_mask_int);
int i = 0;
int j = 0;
@ -366,33 +369,54 @@ void FillEmptyDimsSTGrad(std::vector<T> *begin, std::vector<T> *end, std::vector
}
} // namespace
void StridedSliceV2GradCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
cnode_ptr_ = kernel_node;
ClearVectors();
auto input_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, kIndex4);
if (input_shape.size() > kStridedSliceV2GradMaxInputShapeSize) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dimension of input tensor must be 8D or lower, but got "
<< input_shape.size() << "D.";
}
output_shape_ = common::AnfAlgo::GetOutputInferShape(kernel_node, 0);
dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, kIndex4);
dtype_grad_attr = AnfAlgo::GetInputDeviceDataType(kernel_node, 1);
size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num == kStridedSliceV2GradDynamicInputsNum) { // Dynamic Shape
return;
}
// in the case that begin, end, size, stride are const value.
std::vector<int64_t> begin_me = common::AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, BEGIN);
(void)std::transform(begin_me.begin(), begin_me.end(), std::back_inserter(begin_),
[](const int64_t &value) { return LongToInt(value); });
auto prim = common::AnfAlgo::GetCNodePrimitive(kernel_node);
MS_EXCEPTION_IF_NULL(prim);
auto strides = prim->GetAttr(STRIDES);
bool StridedSliceV2GradCpuKernelMod::Init(const BaseOperatorPtr &base_operator,
const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
MS_EXCEPTION_IF_NULL(base_operator);
base_operator_ = base_operator;
kernel_name_ = base_operator->name();
std::vector<int64_t> strides_me = common::AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, STRIDES);
std::vector<int64_t> end_me = common::AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, END);
if (inputs.empty()) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', input can not be empty.";
}
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOutputsNum, kernel_name_);
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
auto is_match = MatchKernelAttr(kernel_attr, GetOpSupport());
if (!is_match.first) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel data type: " << kernel_attr;
return false;
}
dtype_ = inputs[kIndex4]->GetDtype();
dtype_grad_attr = inputs[kIndex1]->GetDtype();
return true;
}
int StridedSliceV2GradCpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &) {
if (auto ret = KernelMod::Resize(base_operator, inputs, outputs); ret != KRET_OK) {
return ret;
}
ClearVectors();
begin_shape_ = inputs_[kIndex1]->GetShapeVector();
end_shape_ = inputs_[kIndex2]->GetShapeVector();
stride_shape_ = inputs_[kIndex3]->GetShapeVector();
input_shape_ = inputs[kIndex4]->GetShapeVector();
output_shape_ = outputs[kIndex0]->GetShapeVector();
if (inputs.size() == kStridedSliceV2GradDynamicInputsNum) { // Dynamic Shape
return KRET_OK;
}
auto prim = base_operator->GetPrim();
MS_EXCEPTION_IF_NULL(prim);
std::vector<int64_t> strides_me = GetValue<std::vector<int64_t>>(prim->GetAttr(STRIDES));
std::vector<int64_t> end_me = GetValue<std::vector<int64_t>>(prim->GetAttr(END));
(void)std::transform(strides_me.begin(), strides_me.end(), std::back_inserter(strides_),
[](const int64_t &value) { return LongToInt(value); });
(void)std::transform(end_me.begin(), end_me.end(), std::back_inserter(end_),
@ -405,6 +429,7 @@ void StridedSliceV2GradCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
}
FormatArgs(true);
ExpandAllMemberDims(kStridedSliceV2GradMaxInputShapeSize);
return KRET_OK;
}
void StridedSliceV2GradCpuKernelMod::ClearVectors() {
@ -412,8 +437,6 @@ void StridedSliceV2GradCpuKernelMod::ClearVectors() {
size_.clear();
strides_.clear();
end_.clear();
input_element_num_.clear();
output_element_num_.clear();
input_shape_.clear();
output_shape_.clear();
}
@ -449,33 +472,24 @@ void StridedSliceV2GradCpuKernelMod::ExpandAllMemberDims(size_t expand_dims) {
// init for dynamic shape
template <typename T>
void StridedSliceV2GradCpuKernelMod::InitParams(const std::vector<kernel::AddressPtr> &inputs) {
auto cnode = cnode_ptr_.lock();
ClearVectors();
output_shape_ = common::AnfAlgo::GetOutputInferShape(cnode, 0);
std::string kernel_name = common::AnfAlgo::GetCNodeName(cnode);
auto begin_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(cnode, 1);
auto begin_ptr = static_cast<T *>(inputs[1]->addr);
std::vector<T> begin{begin_ptr, begin_ptr + begin_shape[0]};
auto end_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(cnode, kIndex2);
auto stride_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(cnode, kIndex3);
if (begin_shape.size() != 1 || end_shape.size() != 1 || stride_shape.size() != 1) {
if (begin_shape_.size() != 1 || end_shape_.size() != 1 || stride_shape_.size() != 1) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_
<< "', the dimensions of 'begin', 'end', 'strides' must be 1, "
"but got the dimension of 'begin': "
<< begin_shape.size() << ", the dimension of 'end': " << end_shape.size()
<< ", and the dimension of 'strides': " << stride_shape.size();
<< begin_shape_.size() << ", the dimension of 'end': " << end_shape_.size()
<< ", and the dimension of 'strides': " << stride_shape_.size();
}
auto begin_ptr = static_cast<T *>(inputs[1]->addr);
std::vector<T> begin{begin_ptr, begin_ptr + begin_shape_[0]};
auto end_ptr = static_cast<T *>(inputs[kIndex2]->addr);
auto strides_ptr = static_cast<T *>(inputs[kIndex3]->addr);
std::vector<T> end{end_ptr, end_ptr + end_shape[0]};
std::vector<T> strides{strides_ptr, strides_ptr + stride_shape[0]};
input_shape_ = common::AnfAlgo::GetPrevNodeOutputInferShape(cnode, kIndex4);
std::vector<T> end{end_ptr, end_ptr + end_shape_[0]};
std::vector<T> strides{strides_ptr, strides_ptr + stride_shape_[0]};
shape_dim_output = SizeToInt(output_shape_.size());
slice_len = SizeToInt(begin.size());
FillEmptyDimsSTGrad<T>(&begin, &end, &strides, &input_shape_, &output_shape_);
ParseStrideSliceGradMasksST<T>(cnode, &begin, &end, &strides, &input_shape_, output_shape_, shape_dim_output,
ParseStrideSliceGradMasksST<T>(base_operator_, &begin, &end, &strides, &input_shape_, output_shape_, shape_dim_output,
slice_len);
FillEmptyDimsSTGrad<T>(&begin, &end, &strides, &input_shape_, &output_shape_);
(void)std::transform(begin.begin(), begin.end(), std::back_inserter(begin_), [](const T &value) { return value; });
@ -494,10 +508,6 @@ void StridedSliceV2GradCpuKernelMod::InitParams(const std::vector<kernel::Addres
bool StridedSliceV2GradCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &outputs) {
if (inputs.empty()) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', input can not be empty.";
}
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOutputsNum, kernel_name_);
bool ret = true;
if (dtype_ == kNumberTypeInt32) {
ret = LaunchKernel<int32_t>(inputs, outputs);

View File

@ -20,7 +20,7 @@
#include <memory>
#include <string>
#include <vector>
#include <map>
#include "plugin/device/cpu/kernel/nnacl/fp32_grad/strided_slice_grad.h"
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"
@ -30,7 +30,7 @@ namespace kernel {
constexpr auto kStridedSliceV2Grad = "StridedSliceV2Grad";
constexpr auto kUnknown = "Unknown";
class StridedSliceV2GradCpuKernelMod : public DeprecatedNativeCpuKernelMod {
class StridedSliceV2GradCpuKernelMod : public NativeCpuKernelMod {
public:
StridedSliceV2GradCpuKernelMod() = default;
@ -38,8 +38,10 @@ class StridedSliceV2GradCpuKernelMod : public DeprecatedNativeCpuKernelMod {
~StridedSliceV2GradCpuKernelMod() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override;
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
@ -68,16 +70,18 @@ class StridedSliceV2GradCpuKernelMod : public DeprecatedNativeCpuKernelMod {
template <typename T>
bool CalStridedSliceV2Grad(T *input, T *output);
BaseOperatorPtr base_operator_;
std::vector<int> begin_;
std::vector<int> end_;
std::vector<int> strides_;
std::vector<int> size_;
ShapeVector input_shape_;
int shape_dim_output{0};
int slice_len{0};
std::vector<size_t> input_element_num_;
ShapeVector input_shape_;
ShapeVector begin_shape_;
ShapeVector end_shape_;
ShapeVector stride_shape_;
ShapeVector output_shape_;
std::vector<size_t> output_element_num_;
TypeId dtype_{kTypeUnknown};
TypeId dtype_grad_attr{kTypeUnknown};
std::string kernel_type_{kUnknown};

View File

@ -311,22 +311,22 @@ static std::map<std::string, std::vector<KernelAttr>> support_list_map = {
.AddOutputAttr(kNumberTypeComplex128)}}};
template <typename T>
void ParseStrideSliceMasksST(const CNodePtr &kernel_node, std::vector<T> *begin, std::vector<T> *end,
void ParseStrideSliceMasksST(const BaseOperatorPtr &base_operator, std::vector<T> *begin, std::vector<T> *end,
std::vector<T> *stride, const ShapeVector &input_shape, size_t shape_dim_input,
size_t slice_len) {
std::vector<T> &_begin_attr = *begin;
std::vector<T> &_end_attr = *end;
std::vector<T> &_stride_attr = *stride;
auto begin_mask_int = common::AnfAlgo::GetNodeAttr<int64_t>(kernel_node, kAttrBeginMask);
auto prim = base_operator->GetPrim();
auto begin_mask_int = GetValue<int64_t>(prim->GetAttr(kAttrBeginMask));
auto begin_mask = Dec2Bin(begin_mask_int);
auto end_mask_int = common::AnfAlgo::GetNodeAttr<int64_t>(kernel_node, kAttrEndMask);
auto end_mask_int = GetValue<int64_t>(prim->GetAttr(kAttrEndMask));
auto end_mask = Dec2Bin(end_mask_int);
auto ellipsis_mask_int = common::AnfAlgo::GetNodeAttr<int64_t>(kernel_node, kAttrEllipsisMask);
auto ellipsis_mask_int = GetValue<int64_t>(prim->GetAttr(kAttrEllipsisMask));
auto ellipsis_mask = Dec2Bin(ellipsis_mask_int);
auto new_axis_mask_int = common::AnfAlgo::GetNodeAttr<int64_t>(kernel_node, kAttrNewAxisMask);
auto new_axis_mask_int = GetValue<int64_t>(prim->GetAttr(kAttrNewAxisMask));
auto new_axis_mask = Dec2Bin(new_axis_mask_int);
auto shrink_axis_mask_int = common::AnfAlgo::GetNodeAttr<int64_t>(kernel_node, kAttrShrinkAxisMask);
auto shrink_axis_mask_int = GetValue<int64_t>(prim->GetAttr(kAttrShrinkAxisMask));
auto shrink_axis_mask = Dec2Bin(shrink_axis_mask_int);
size_t i = 0;
size_t j = 0;
@ -380,14 +380,14 @@ void ParseStrideSliceMasksST(const CNodePtr &kernel_node, std::vector<T> *begin,
}
template <typename T>
void FillEmptyDimsST(const CNodePtr &kernel_node, std::vector<T> *begin, std::vector<T> *end, std::vector<T> *stride,
ShapeVector *input_shape) {
void FillEmptyDimsST(const BaseOperatorPtr &base_operator, std::vector<T> *begin, std::vector<T> *end,
std::vector<T> *stride, ShapeVector *input_shape) {
std::vector<T> &_begin = *begin;
std::vector<T> &_end = *end;
std::vector<T> &_stride = *stride;
auto &_input_shape = *input_shape;
if (_begin.size() != _end.size() || _begin.size() != _stride.size() || _begin.size() > _input_shape.size()) {
MS_LOG(EXCEPTION) << "For '" << common::AnfAlgo::GetCNodeName(kernel_node)
MS_LOG(EXCEPTION) << "For '" << base_operator->name()
<< "', the length of 'begin', 'stride' and 'end' should be equal "
"and less than or equal to the dimension of 'input_x', but got the length of 'begin': "
<< _begin.size() << ", the length of 'stride': " << _stride.size()
@ -416,27 +416,59 @@ void FillEmptyDimsST(const CNodePtr &kernel_node, std::vector<T> *begin, std::ve
}
} // namespace
void StridedSliceV2CpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
cnode_ptr_ = kernel_node;
input_shape_ = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
output_shape_ = common::AnfAlgo::GetOutputInferShape(kernel_node, 0);
dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 1) {
return;
bool StridedSliceV2CpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
MS_EXCEPTION_IF_NULL(base_operator);
base_operator_ = base_operator;
kernel_name_ = base_operator->name();
if (inputs.size() != kStridedSliceV2InputsNum && inputs.size() != kStridedSliceV2DynamicInputsNum) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the number of inputs should be " << kStridedSliceV2InputsNum
<< " or " << kStridedSliceV2DynamicInputsNum << ", but got " << inputs.size();
}
// for begin, end, stride are const input
auto begin = common::AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, kAttrBegin);
auto end = common::AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, kAttrEnd);
auto stride = common::AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, kAttrStrides);
InitSliceParam<int64_t>(kernel_node, &begin, &end, &stride);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kStridedSliceV2OutputsNum, kernel_name_);
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
auto is_match = MatchKernelAttr(kernel_attr, GetOpSupport());
if (!is_match.first) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel data type: " << kernel_attr;
return false;
}
dtype_ = inputs[kIndex0]->GetDtype();
dtype_attr_ = inputs[kIndex1]->GetDtype();
return true;
}
int StridedSliceV2CpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &) {
if (auto ret = KernelMod::Resize(base_operator, inputs, outputs); ret != KRET_OK) {
return ret;
}
input_shape_ = inputs[kIndex0]->GetShapeVector();
begin_shape_ = inputs[kIndex1]->GetShapeVector();
end_shape_ = inputs[kIndex2]->GetShapeVector();
stride_shape_ = inputs[kIndex3]->GetShapeVector();
output_shape_ = outputs[kIndex0]->GetShapeVector();
if (inputs.size() == kStridedSliceV2DynamicInputsNum) {
return KRET_OK;
}
auto prim = base_operator->GetPrim();
auto begin = GetValue<std::vector<int64_t>>(prim->GetAttr(kAttrBegin));
auto end = GetValue<std::vector<int64_t>>(prim->GetAttr(kAttrEnd));
auto stride = GetValue<std::vector<int64_t>>(prim->GetAttr(kAttrStrides));
InitSliceParam<int64_t>(base_operator, &begin, &end, &stride);
parallel_ = MatchParallelPattern();
if (parallel_) {
InitParallelParam();
}
return KRET_OK;
}
bool StridedSliceV2CpuKernelMod::MatchParallelPattern() {
@ -492,8 +524,8 @@ void StridedSliceV2CpuKernelMod::InitParallelParam() {
}
template <typename T>
void StridedSliceV2CpuKernelMod::InitSliceParam(const CNodePtr &kernel_node, std::vector<T> *begin, std::vector<T> *end,
std::vector<T> *stride) {
void StridedSliceV2CpuKernelMod::InitSliceParam(const BaseOperatorPtr &base_operator, std::vector<T> *begin,
std::vector<T> *end, std::vector<T> *stride) {
static const std::unordered_map<TypeId, std::pair<TypeIdC, int>> type_convert_map = {
{kNumberTypeBool, {::kNumberTypeBool, sizeof(bool)}},
{kNumberTypeInt8, {::kNumberTypeInt8, sizeof(int8_t)}},
@ -521,16 +553,16 @@ void StridedSliceV2CpuKernelMod::InitSliceParam(const CNodePtr &kernel_node, std
slice_param_.data_type = type_pair->second.first;
auto input_shape_pad = input_shape_;
shape_dim_input = input_shape_.size();
FillEmptyDimsST<T>(kernel_node, begin, end, stride, &input_shape_pad);
ParseStrideSliceMasksST<T>(kernel_node, begin, end, stride, input_shape_, shape_dim_input, slice_len);
FillEmptyDimsST<T>(kernel_node, begin, end, stride, &input_shape_pad);
FillEmptyDimsST<T>(base_operator, begin, end, stride, &input_shape_pad);
ParseStrideSliceMasksST<T>(base_operator, begin, end, stride, input_shape_, shape_dim_input, slice_len);
FillEmptyDimsST<T>(base_operator, begin, end, stride, &input_shape_pad);
std::vector<T> &_begin = *begin;
std::vector<T> &_end = *end;
std::vector<T> &_stride = *stride;
for (size_t i = 0; i < DIMENSION_8D; i++) {
slice_param_.in_shape_[i] = SizeToInt(input_shape_pad[i]);
if (dtype_attr == kNumberTypeInt64) {
if (dtype_attr_ == kNumberTypeInt64) {
slice_param_.begins_[i] = LongToInt(_begin[i]);
slice_param_.ends_[i] = LongToInt(_end[i]);
slice_param_.strides_[i] = LongToInt(_stride[i]);
@ -596,66 +628,43 @@ void StridedSliceV2CpuKernelMod::ParallelRun(const uint8_t *input_addr, uint8_t
}
template <typename T>
bool StridedSliceV2CpuKernelMod::StridedSliceV2LaunchDynamicType(const std::vector<kernel::AddressPtr> &inputs) {
auto cnode = cnode_ptr_.lock();
auto begin_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(cnode, 1);
auto end_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(cnode, 2);
auto stride_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(cnode, 3);
if (begin_shape.size() != 1 || end_shape.size() != 1 || stride_shape.size() != 1) {
void StridedSliceV2CpuKernelMod::StridedSliceV2LaunchDynamicType(const std::vector<kernel::AddressPtr> &inputs) {
if (begin_shape_.size() != 1 || end_shape_.size() != 1 || stride_shape_.size() != 1) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_
<< "', the dimension of 'begin', 'end', 'strides' should be equal "
"to 1, but got the dimension of 'begin': "
<< begin_shape.size() << ", the dimension of 'end': " << end_shape.size()
<< ", and the dimension of 'strides': " << stride_shape.size();
<< begin_shape_.size() << ", the dimension of 'end': " << end_shape_.size()
<< ", and the dimension of 'strides': " << stride_shape_.size();
}
auto begin_ptr = static_cast<T *>(inputs[1]->addr);
auto end_ptr = static_cast<T *>(inputs[2]->addr);
auto strides_ptr = static_cast<T *>(inputs[3]->addr);
std::vector<T> begin{begin_ptr, begin_ptr + begin_shape[0]};
std::vector<T> end{end_ptr, end_ptr + end_shape[0]};
std::vector<T> stride{strides_ptr, strides_ptr + stride_shape[0]};
std::vector<T> begin{begin_ptr, begin_ptr + begin_shape_[0]};
std::vector<T> end{end_ptr, end_ptr + end_shape_[0]};
std::vector<T> stride{strides_ptr, strides_ptr + stride_shape_[0]};
slice_len = begin.size();
InitSliceParam<T>(cnode, &begin, &end, &stride);
return true;
InitSliceParam<T>(base_operator_, &begin, &end, &stride);
}
bool StridedSliceV2CpuKernelMod::StridedSliceV2LaunchCal(const std::vector<kernel::AddressPtr> &inputs,
void StridedSliceV2CpuKernelMod::StridedSliceV2LaunchCal(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs) {
if (inputs.size() != kStridedSliceV2InputsNum && inputs.size() != kStridedSliceV2DynamicInputsNum) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the number of inputs should be " << kStridedSliceV2InputsNum
<< " or " << kStridedSliceV2DynamicInputsNum << ", but got " << inputs.size();
// for begin, end, stride are not const input
if (dtype_attr_ == kNumberTypeInt32) {
StridedSliceV2LaunchDynamicType<int32_t>(inputs);
} else {
StridedSliceV2LaunchDynamicType<int64_t>(inputs);
}
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kStridedSliceV2OutputsNum, kernel_name_);
auto cnode = cnode_ptr_.lock();
size_t input_num = common::AnfAlgo::GetInputTensorNum(cnode);
if (input_num == kStridedSliceV2DynamicInputsNum) {
bool flag = false;
// for begin, end, stride are not const input
dtype_attr = AnfAlgo::GetInputDeviceDataType(cnode, 1);
if (dtype_attr == kNumberTypeInt32) {
flag = StridedSliceV2LaunchDynamicType<int32_t>(inputs);
} else {
flag = StridedSliceV2LaunchDynamicType<int64_t>(inputs);
}
if (!flag) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the begin、end、stride is calculation error.";
}
parallel_ = MatchParallelPattern();
if (parallel_) {
InitParallelParam();
}
parallel_ = MatchParallelPattern();
if (parallel_) {
InitParallelParam();
}
return true;
}
bool StridedSliceV2CpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> & /* workspace */,
const std::vector<kernel::AddressPtr> &outputs) {
bool ret = StridedSliceV2LaunchCal(inputs, outputs);
if (ret != true) {
MS_LOG(EXCEPTION) << "For StridedSliceV2 LaunchCal failed.";
if (inputs.size() == kStridedSliceV2DynamicInputsNum) {
StridedSliceV2LaunchCal(inputs, outputs);
}
auto input_addr = static_cast<uint8_t *>(inputs[0]->addr);
auto output_addr = static_cast<uint8_t *>(outputs[0]->addr);

View File

@ -20,6 +20,7 @@
#include <vector>
#include <memory>
#include <string>
#include <map>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"
#include "nnacl/fp32/strided_slice_fp32.h"
@ -27,13 +28,15 @@
namespace mindspore {
namespace kernel {
constexpr auto kStridedSliceV2 = "StridedSliceV2";
class StridedSliceV2CpuKernelMod : public DeprecatedNativeCpuKernelMod {
class StridedSliceV2CpuKernelMod : public NativeCpuKernelMod {
public:
StridedSliceV2CpuKernelMod() = default;
~StridedSliceV2CpuKernelMod() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override;
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
@ -44,33 +47,38 @@ class StridedSliceV2CpuKernelMod : public DeprecatedNativeCpuKernelMod {
enum ParallelStrategy { kOnSplitAxis, kOnOuter };
template <typename T>
bool StridedSliceV2LaunchDynamicType(const std::vector<kernel::AddressPtr> &inputs);
void StridedSliceV2LaunchDynamicType(const std::vector<kernel::AddressPtr> &inputs);
template <typename T>
void InitSliceParam(const CNodePtr &kernel_node, std::vector<T> *begin, std::vector<T> *end, std::vector<T> *stride);
void InitSliceParam(const BaseOperatorPtr &base_operator, std::vector<T> *begin, std::vector<T> *end,
std::vector<T> *stride);
bool MatchParallelPattern();
void InitParallelParam();
void ParallelRun(const uint8_t *input_addr, uint8_t *output_addr, int thread_num);
bool StridedSliceV2LaunchCal(const std::vector<kernel::AddressPtr> &inputs,
void StridedSliceV2LaunchCal(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs);
common::Status RunTaskOnOuter(const uint8_t *input_addr, uint8_t *output_addr, int start_pos);
common::Status RunTaskOnSplitAxis(const uint8_t *input_addr, uint8_t *output_addr, int start_pos);
void ParseMasks(const CNodePtr &kernel_node);
TypeId dtype_;
TypeId dtype_attr;
TypeId dtype_attr_;
int data_size_{4};
int split_axis_{-1};
int inner_{1};
int outer_{1};
int cal_num_per_thread_{1};
bool parallel_{false};
BaseOperatorPtr base_operator_;
size_t inputs_num_;
size_t shape_dim_input;
size_t slice_len;
ParallelStrategy parallel_strategy_{kOnSplitAxis};
ShapeVector input_shape_;
ShapeVector output_shape_;
ShapeVector begin_shape_;
ShapeVector end_shape_;
ShapeVector stride_shape_;
StridedSliceParameter slice_param_;
};
} // namespace kernel