!49555 ops StridedSliceV2, StridedSliceV2Grad support the dynamic shape feature

Merge pull request !49555 from wang_ziqi/br_stridedslice_v2
i-robot 2023-03-01 06:36:00 +00:00 committed by Gitee
commit 120f7b597e
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
4 changed files with 174 additions and 143 deletions
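
The diffs below move both StridedSliceV2CpuKernelMod and StridedSliceV2GradCpuKernelMod from the deprecated InitKernel(CNodePtr) interface to the Init/Resize/Launch interface, which is what enables dynamic shapes: Init validates data types once, Resize re-reads shapes whenever they change, and Launch reads the begin/end/strides values from the input tensors at run time. The following is a minimal sketch of that lifecycle, not MindSpore code: Tensor and the member layout are hypothetical stand-ins for KernelTensorPtr and the real kernel state.

// Hedged sketch of the Init / Resize / Launch split adopted in this commit.
// Tensor is a hypothetical stand-in for KernelTensorPtr; none of this is the
// framework's actual API.
#include <cstdint>
#include <vector>

struct Tensor {
  std::vector<int64_t> shape;  // known only after shape inference
  const void *data;            // known only at launch time
};

class SliceKernelSketch {
 public:
  // Called once per kernel instance: static checks (dtype, attribute presence).
  bool Init(const std::vector<Tensor> &inputs) { return inputs.size() >= 4; }

  // Called every time shapes change: cache shapes only, no data access yet.
  int Resize(const std::vector<Tensor> &inputs) {
    begin_shape_ = inputs[1].shape;  // length of the 1-D begin tensor
    return 0;                        // stands in for KRET_OK
  }

  // Called per step: begin/end/strides are tensor inputs here, so their
  // values are only readable now, using the lengths cached in Resize().
  bool Launch(const std::vector<Tensor> &inputs) {
    const auto *begin = static_cast<const int64_t *>(inputs[1].data);
    std::vector<int64_t> begin_values(begin, begin + begin_shape_[0]);
    return !begin_values.empty();
  }

 private:
  std::vector<int64_t> begin_shape_;
};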

View File

@@ -259,21 +259,24 @@ std::vector<bool> GradDec2Bin(const int64_t &mask) {
 }
 template <typename T>
-void ParseStrideSliceGradMasksST(const CNodePtr &kernel_node, std::vector<T> *begin, std::vector<T> *end,
+void ParseStrideSliceGradMasksST(const BaseOperatorPtr &base_operator, std::vector<T> *begin, std::vector<T> *end,
                                  std::vector<T> *stride, ShapeVector *input_shape, const ShapeVector output_shape,
                                  int shape_dim_output, int slice_len) {
   std::vector<T> &_begin_attr = *begin;
   std::vector<T> &_end_attr = *end;
   std::vector<T> &_stride_attr = *stride;
-  auto begin_mask_int = common::AnfAlgo::GetNodeAttr<int64_t>(kernel_node, kAttrBeginMask);
+  auto prim = base_operator->GetPrim();
+  MS_EXCEPTION_IF_NULL(prim);
+  auto begin_mask_int = GetValue<int64_t>(prim->GetAttr(kAttrBeginMask));
   auto begin_mask = GradDec2Bin(begin_mask_int);
-  auto end_mask_int = common::AnfAlgo::GetNodeAttr<int64_t>(kernel_node, kAttrEndMask);
+  auto end_mask_int = GetValue<int64_t>(prim->GetAttr(kAttrEndMask));
   auto end_mask = GradDec2Bin(end_mask_int);
-  auto ellipsis_mask_int = common::AnfAlgo::GetNodeAttr<int64_t>(kernel_node, kAttrEllipsisMask);
+  auto ellipsis_mask_int = GetValue<int64_t>(prim->GetAttr(kAttrEllipsisMask));
   auto ellipsis_mask = GradDec2Bin(ellipsis_mask_int);
-  auto new_axis_mask_int = common::AnfAlgo::GetNodeAttr<int64_t>(kernel_node, kAttrNewAxisMask);
+  auto new_axis_mask_int = GetValue<int64_t>(prim->GetAttr(kAttrNewAxisMask));
   auto new_axis_mask = GradDec2Bin(new_axis_mask_int);
-  auto shrink_axis_mask_int = common::AnfAlgo::GetNodeAttr<int64_t>(kernel_node, kAttrShrinkAxisMask);
+  auto shrink_axis_mask_int = GetValue<int64_t>(prim->GetAttr(kAttrShrinkAxisMask));
   auto shrink_axis_mask = GradDec2Bin(shrink_axis_mask_int);
   int i = 0;
   int j = 0;
@@ -366,33 +369,54 @@ void FillEmptyDimsSTGrad(std::vector<T> *begin, std::vector<T> *end, std::vector
 }
 } // namespace
-void StridedSliceV2GradCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
-  MS_EXCEPTION_IF_NULL(kernel_node);
-  kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
-  cnode_ptr_ = kernel_node;
-  ClearVectors();
-  auto input_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, kIndex4);
-  if (input_shape.size() > kStridedSliceV2GradMaxInputShapeSize) {
-    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dimension of input tensor must be 8D or lower, but got "
-                      << input_shape.size() << "D.";
-  }
-  output_shape_ = common::AnfAlgo::GetOutputInferShape(kernel_node, 0);
-  dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, kIndex4);
-  dtype_grad_attr = AnfAlgo::GetInputDeviceDataType(kernel_node, 1);
-  size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
-  if (input_num == kStridedSliceV2GradDynamicInputsNum) {  // Dynamic Shape
-    return;
-  }
-  // in the case that begin, end, size, stride are const value.
-  std::vector<int64_t> begin_me = common::AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, BEGIN);
-  (void)std::transform(begin_me.begin(), begin_me.end(), std::back_inserter(begin_),
-                       [](const int64_t &value) { return LongToInt(value); });
-  auto prim = common::AnfAlgo::GetCNodePrimitive(kernel_node);
-  MS_EXCEPTION_IF_NULL(prim);
-  auto strides = prim->GetAttr(STRIDES);
-  std::vector<int64_t> strides_me = common::AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, STRIDES);
-  std::vector<int64_t> end_me = common::AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, END);
+bool StridedSliceV2GradCpuKernelMod::Init(const BaseOperatorPtr &base_operator,
+                                          const std::vector<KernelTensorPtr> &inputs,
+                                          const std::vector<KernelTensorPtr> &outputs) {
+  MS_EXCEPTION_IF_NULL(base_operator);
+  base_operator_ = base_operator;
+  kernel_name_ = base_operator->name();
+  if (inputs.empty()) {
+    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', input can not be empty.";
+  }
+  CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOutputsNum, kernel_name_);
+  auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
+  auto is_match = MatchKernelAttr(kernel_attr, GetOpSupport());
+  if (!is_match.first) {
+    MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel data type: " << kernel_attr;
+    return false;
+  }
+  dtype_ = inputs[kIndex4]->GetDtype();
+  dtype_grad_attr = inputs[kIndex1]->GetDtype();
+  return true;
+}
+int StridedSliceV2GradCpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
+                                           const std::vector<KernelTensorPtr> &inputs,
+                                           const std::vector<KernelTensorPtr> &outputs,
+                                           const std::map<uint32_t, tensor::TensorPtr> &) {
+  if (auto ret = KernelMod::Resize(base_operator, inputs, outputs); ret != KRET_OK) {
+    return ret;
+  }
+  ClearVectors();
+  begin_shape_ = inputs_[kIndex1]->GetShapeVector();
+  end_shape_ = inputs_[kIndex2]->GetShapeVector();
+  stride_shape_ = inputs_[kIndex3]->GetShapeVector();
+  input_shape_ = inputs[kIndex4]->GetShapeVector();
+  output_shape_ = outputs[kIndex0]->GetShapeVector();
+  if (inputs.size() == kStridedSliceV2GradDynamicInputsNum) {  // Dynamic Shape
+    return KRET_OK;
+  }
+  auto prim = base_operator->GetPrim();
+  MS_EXCEPTION_IF_NULL(prim);
+  std::vector<int64_t> strides_me = GetValue<std::vector<int64_t>>(prim->GetAttr(STRIDES));
+  std::vector<int64_t> end_me = GetValue<std::vector<int64_t>>(prim->GetAttr(END));
   (void)std::transform(strides_me.begin(), strides_me.end(), std::back_inserter(strides_),
                        [](const int64_t &value) { return LongToInt(value); });
   (void)std::transform(end_me.begin(), end_me.end(), std::back_inserter(end_),
@@ -405,6 +429,7 @@ void StridedSliceV2GradCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
   }
   FormatArgs(true);
   ExpandAllMemberDims(kStridedSliceV2GradMaxInputShapeSize);
+  return KRET_OK;
 }
 void StridedSliceV2GradCpuKernelMod::ClearVectors() {
@@ -412,8 +437,6 @@ void StridedSliceV2GradCpuKernelMod::ClearVectors() {
   size_.clear();
   strides_.clear();
   end_.clear();
-  input_element_num_.clear();
-  output_element_num_.clear();
   input_shape_.clear();
   output_shape_.clear();
 }
@@ -449,33 +472,24 @@ void StridedSliceV2GradCpuKernelMod::ExpandAllMemberDims(size_t expand_dims) {
 // init for dynamic shape
 template <typename T>
 void StridedSliceV2GradCpuKernelMod::InitParams(const std::vector<kernel::AddressPtr> &inputs) {
-  auto cnode = cnode_ptr_.lock();
-  ClearVectors();
-  output_shape_ = common::AnfAlgo::GetOutputInferShape(cnode, 0);
-  std::string kernel_name = common::AnfAlgo::GetCNodeName(cnode);
-  auto begin_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(cnode, 1);
-  auto begin_ptr = static_cast<T *>(inputs[1]->addr);
-  std::vector<T> begin{begin_ptr, begin_ptr + begin_shape[0]};
-  auto end_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(cnode, kIndex2);
-  auto stride_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(cnode, kIndex3);
-  if (begin_shape.size() != 1 || end_shape.size() != 1 || stride_shape.size() != 1) {
+  if (begin_shape_.size() != 1 || end_shape_.size() != 1 || stride_shape_.size() != 1) {
     MS_LOG(EXCEPTION) << "For '" << kernel_name_
                       << "', the dimensions of 'begin', 'end', 'strides' must be 1, "
                          "but got the dimension of 'begin': "
-                      << begin_shape.size() << ", the dimension of 'end': " << end_shape.size()
-                      << ", and the dimension of 'strides': " << stride_shape.size();
+                      << begin_shape_.size() << ", the dimension of 'end': " << end_shape_.size()
+                      << ", and the dimension of 'strides': " << stride_shape_.size();
   }
+  auto begin_ptr = static_cast<T *>(inputs[1]->addr);
+  std::vector<T> begin{begin_ptr, begin_ptr + begin_shape_[0]};
   auto end_ptr = static_cast<T *>(inputs[kIndex2]->addr);
   auto strides_ptr = static_cast<T *>(inputs[kIndex3]->addr);
-  std::vector<T> end{end_ptr, end_ptr + end_shape[0]};
-  std::vector<T> strides{strides_ptr, strides_ptr + stride_shape[0]};
-  input_shape_ = common::AnfAlgo::GetPrevNodeOutputInferShape(cnode, kIndex4);
+  std::vector<T> end{end_ptr, end_ptr + end_shape_[0]};
+  std::vector<T> strides{strides_ptr, strides_ptr + stride_shape_[0]};
   shape_dim_output = SizeToInt(output_shape_.size());
   slice_len = SizeToInt(begin.size());
   FillEmptyDimsSTGrad<T>(&begin, &end, &strides, &input_shape_, &output_shape_);
-  ParseStrideSliceGradMasksST<T>(cnode, &begin, &end, &strides, &input_shape_, output_shape_, shape_dim_output,
+  ParseStrideSliceGradMasksST<T>(base_operator_, &begin, &end, &strides, &input_shape_, output_shape_, shape_dim_output,
                                  slice_len);
   FillEmptyDimsSTGrad<T>(&begin, &end, &strides, &input_shape_, &output_shape_);
   (void)std::transform(begin.begin(), begin.end(), std::back_inserter(begin_), [](const T &value) { return value; });
@@ -494,10 +508,6 @@ void StridedSliceV2GradCpuKernelMod::InitParams(const std::vector<kernel::Addres
 bool StridedSliceV2GradCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs,
                                             const std::vector<kernel::AddressPtr> &,
                                             const std::vector<kernel::AddressPtr> &outputs) {
-  if (inputs.empty()) {
-    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', input can not be empty.";
-  }
-  CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOutputsNum, kernel_name_);
   bool ret = true;
   if (dtype_ == kNumberTypeInt32) {
     ret = LaunchKernel<int32_t>(inputs, outputs);
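
ParseStrideSliceGradMasksST above reads the five masks (begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask) from the primitive and expands each one with GradDec2Bin. The sketch below is not the repository's GradDec2Bin; it only illustrates the underlying idea that an int64 bitmask becomes one boolean per dimension, with bit i controlling dimension i.

// Hedged illustration of mask expansion (not the repository's GradDec2Bin/Dec2Bin).
#include <cstddef>
#include <cstdint>
#include <vector>

std::vector<bool> MaskToBits(int64_t mask, size_t dims) {
  std::vector<bool> bits(dims, false);
  for (size_t i = 0; i < dims; ++i) {
    // Bit i set means the mask applies to dimension i; for begin_mask this
    // conventionally means "ignore begin[i] and start from the dimension's bound".
    bits[i] = ((mask >> i) & 1) != 0;
  }
  return bits;
}
// Example: mask = 0b0101 over 4 dims -> {true, false, true, false}.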

View File

@@ -20,7 +20,7 @@
 #include <memory>
 #include <string>
 #include <vector>
+#include <map>
 #include "plugin/device/cpu/kernel/nnacl/fp32_grad/strided_slice_grad.h"
 #include "plugin/device/cpu/kernel/cpu_kernel.h"
 #include "plugin/factory/ms_factory.h"
@@ -30,7 +30,7 @@ namespace kernel {
 constexpr auto kStridedSliceV2Grad = "StridedSliceV2Grad";
 constexpr auto kUnknown = "Unknown";
-class StridedSliceV2GradCpuKernelMod : public DeprecatedNativeCpuKernelMod {
+class StridedSliceV2GradCpuKernelMod : public NativeCpuKernelMod {
  public:
   StridedSliceV2GradCpuKernelMod() = default;
@@ -38,8 +38,10 @@ class StridedSliceV2GradCpuKernelMod : public DeprecatedNativeCpuKernelMod {
   ~StridedSliceV2GradCpuKernelMod() override = default;
-  void InitKernel(const CNodePtr &kernel_node) override;
+  bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
+            const std::vector<KernelTensorPtr> &outputs) override;
+  int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
+             const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;
   bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
               const std::vector<AddressPtr> &outputs) override;
@@ -68,16 +70,18 @@ class StridedSliceV2GradCpuKernelMod : public DeprecatedNativeCpuKernelMod {
   template <typename T>
   bool CalStridedSliceV2Grad(T *input, T *output);
+  BaseOperatorPtr base_operator_;
   std::vector<int> begin_;
   std::vector<int> end_;
   std::vector<int> strides_;
   std::vector<int> size_;
-  ShapeVector input_shape_;
   int shape_dim_output{0};
   int slice_len{0};
-  std::vector<size_t> input_element_num_;
+  ShapeVector input_shape_;
+  ShapeVector begin_shape_;
+  ShapeVector end_shape_;
+  ShapeVector stride_shape_;
   ShapeVector output_shape_;
-  std::vector<size_t> output_element_num_;
   TypeId dtype_{kTypeUnknown};
   TypeId dtype_grad_attr{kTypeUnknown};
   std::string kernel_type_{kUnknown};

View File

@@ -311,22 +311,22 @@ static std::map<std::string, std::vector<KernelAttr>> support_list_map = {
      .AddOutputAttr(kNumberTypeComplex128)}}};
 template <typename T>
-void ParseStrideSliceMasksST(const CNodePtr &kernel_node, std::vector<T> *begin, std::vector<T> *end,
+void ParseStrideSliceMasksST(const BaseOperatorPtr &base_operator, std::vector<T> *begin, std::vector<T> *end,
                              std::vector<T> *stride, const ShapeVector &input_shape, size_t shape_dim_input,
                              size_t slice_len) {
   std::vector<T> &_begin_attr = *begin;
   std::vector<T> &_end_attr = *end;
   std::vector<T> &_stride_attr = *stride;
-  auto begin_mask_int = common::AnfAlgo::GetNodeAttr<int64_t>(kernel_node, kAttrBeginMask);
+  auto prim = base_operator->GetPrim();
+  auto begin_mask_int = GetValue<int64_t>(prim->GetAttr(kAttrBeginMask));
   auto begin_mask = Dec2Bin(begin_mask_int);
-  auto end_mask_int = common::AnfAlgo::GetNodeAttr<int64_t>(kernel_node, kAttrEndMask);
+  auto end_mask_int = GetValue<int64_t>(prim->GetAttr(kAttrEndMask));
   auto end_mask = Dec2Bin(end_mask_int);
-  auto ellipsis_mask_int = common::AnfAlgo::GetNodeAttr<int64_t>(kernel_node, kAttrEllipsisMask);
+  auto ellipsis_mask_int = GetValue<int64_t>(prim->GetAttr(kAttrEllipsisMask));
   auto ellipsis_mask = Dec2Bin(ellipsis_mask_int);
-  auto new_axis_mask_int = common::AnfAlgo::GetNodeAttr<int64_t>(kernel_node, kAttrNewAxisMask);
+  auto new_axis_mask_int = GetValue<int64_t>(prim->GetAttr(kAttrNewAxisMask));
   auto new_axis_mask = Dec2Bin(new_axis_mask_int);
-  auto shrink_axis_mask_int = common::AnfAlgo::GetNodeAttr<int64_t>(kernel_node, kAttrShrinkAxisMask);
+  auto shrink_axis_mask_int = GetValue<int64_t>(prim->GetAttr(kAttrShrinkAxisMask));
   auto shrink_axis_mask = Dec2Bin(shrink_axis_mask_int);
   size_t i = 0;
   size_t j = 0;
@@ -380,14 +380,14 @@ void ParseStrideSliceMasksST(const CNodePtr &kernel_node, std::vector<T> *begin,
 }
 template <typename T>
-void FillEmptyDimsST(const CNodePtr &kernel_node, std::vector<T> *begin, std::vector<T> *end, std::vector<T> *stride,
-                     ShapeVector *input_shape) {
+void FillEmptyDimsST(const BaseOperatorPtr &base_operator, std::vector<T> *begin, std::vector<T> *end,
+                     std::vector<T> *stride, ShapeVector *input_shape) {
   std::vector<T> &_begin = *begin;
   std::vector<T> &_end = *end;
   std::vector<T> &_stride = *stride;
   auto &_input_shape = *input_shape;
   if (_begin.size() != _end.size() || _begin.size() != _stride.size() || _begin.size() > _input_shape.size()) {
-    MS_LOG(EXCEPTION) << "For '" << common::AnfAlgo::GetCNodeName(kernel_node)
+    MS_LOG(EXCEPTION) << "For '" << base_operator->name()
                       << "', the length of 'begin', 'stride' and 'end' should be equal "
                          "and less than or equal to the dimension of 'input_x', but got the length of 'begin': "
                       << _begin.size() << ", the length of 'stride': " << _stride.size()
@@ -416,27 +416,59 @@ void FillEmptyDimsST(const CNodePtr &kernel_node, std::vector<T> *begin, std::ve
 }
 } // namespace
-void StridedSliceV2CpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
-  MS_EXCEPTION_IF_NULL(kernel_node);
-  kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
-  cnode_ptr_ = kernel_node;
-  input_shape_ = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
-  output_shape_ = common::AnfAlgo::GetOutputInferShape(kernel_node, 0);
-  dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
-  size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
-  if (input_num != 1) {
-    return;
-  }
-  // for begin, end, stride are const input
-  auto begin = common::AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, kAttrBegin);
-  auto end = common::AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, kAttrEnd);
-  auto stride = common::AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, kAttrStrides);
-  InitSliceParam<int64_t>(kernel_node, &begin, &end, &stride);
+bool StridedSliceV2CpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
+                                      const std::vector<KernelTensorPtr> &outputs) {
+  MS_EXCEPTION_IF_NULL(base_operator);
+  base_operator_ = base_operator;
+  kernel_name_ = base_operator->name();
+  if (inputs.size() != kStridedSliceV2InputsNum && inputs.size() != kStridedSliceV2DynamicInputsNum) {
+    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the number of inputs should be " << kStridedSliceV2InputsNum
+                      << " or " << kStridedSliceV2DynamicInputsNum << ", but got " << inputs.size();
+  }
+  CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kStridedSliceV2OutputsNum, kernel_name_);
+  auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
+  auto is_match = MatchKernelAttr(kernel_attr, GetOpSupport());
+  if (!is_match.first) {
+    MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel data type: " << kernel_attr;
+    return false;
+  }
+  dtype_ = inputs[kIndex0]->GetDtype();
+  dtype_attr_ = inputs[kIndex1]->GetDtype();
+  return true;
+}
+int StridedSliceV2CpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
+                                       const std::vector<KernelTensorPtr> &outputs,
+                                       const std::map<uint32_t, tensor::TensorPtr> &) {
+  if (auto ret = KernelMod::Resize(base_operator, inputs, outputs); ret != KRET_OK) {
+    return ret;
+  }
+  input_shape_ = inputs[kIndex0]->GetShapeVector();
+  begin_shape_ = inputs[kIndex1]->GetShapeVector();
+  end_shape_ = inputs[kIndex2]->GetShapeVector();
+  stride_shape_ = inputs[kIndex3]->GetShapeVector();
+  output_shape_ = outputs[kIndex0]->GetShapeVector();
+  if (inputs.size() == kStridedSliceV2DynamicInputsNum) {
+    return KRET_OK;
+  }
+  auto prim = base_operator->GetPrim();
+  auto begin = GetValue<std::vector<int64_t>>(prim->GetAttr(kAttrBegin));
+  auto end = GetValue<std::vector<int64_t>>(prim->GetAttr(kAttrEnd));
+  auto stride = GetValue<std::vector<int64_t>>(prim->GetAttr(kAttrStrides));
+  InitSliceParam<int64_t>(base_operator, &begin, &end, &stride);
   parallel_ = MatchParallelPattern();
   if (parallel_) {
     InitParallelParam();
   }
+  return KRET_OK;
 }
 bool StridedSliceV2CpuKernelMod::MatchParallelPattern() {
@@ -492,8 +524,8 @@ void StridedSliceV2CpuKernelMod::InitParallelParam() {
 }
 template <typename T>
-void StridedSliceV2CpuKernelMod::InitSliceParam(const CNodePtr &kernel_node, std::vector<T> *begin, std::vector<T> *end,
-                                                std::vector<T> *stride) {
+void StridedSliceV2CpuKernelMod::InitSliceParam(const BaseOperatorPtr &base_operator, std::vector<T> *begin,
                                                std::vector<T> *end, std::vector<T> *stride) {
   static const std::unordered_map<TypeId, std::pair<TypeIdC, int>> type_convert_map = {
     {kNumberTypeBool, {::kNumberTypeBool, sizeof(bool)}},
     {kNumberTypeInt8, {::kNumberTypeInt8, sizeof(int8_t)}},
@@ -521,16 +553,16 @@ void StridedSliceV2CpuKernelMod::InitSliceParam(const CNodePtr &kernel_node, std
   slice_param_.data_type = type_pair->second.first;
   auto input_shape_pad = input_shape_;
   shape_dim_input = input_shape_.size();
-  FillEmptyDimsST<T>(kernel_node, begin, end, stride, &input_shape_pad);
-  ParseStrideSliceMasksST<T>(kernel_node, begin, end, stride, input_shape_, shape_dim_input, slice_len);
-  FillEmptyDimsST<T>(kernel_node, begin, end, stride, &input_shape_pad);
+  FillEmptyDimsST<T>(base_operator, begin, end, stride, &input_shape_pad);
+  ParseStrideSliceMasksST<T>(base_operator, begin, end, stride, input_shape_, shape_dim_input, slice_len);
+  FillEmptyDimsST<T>(base_operator, begin, end, stride, &input_shape_pad);
   std::vector<T> &_begin = *begin;
   std::vector<T> &_end = *end;
   std::vector<T> &_stride = *stride;
   for (size_t i = 0; i < DIMENSION_8D; i++) {
     slice_param_.in_shape_[i] = SizeToInt(input_shape_pad[i]);
-    if (dtype_attr == kNumberTypeInt64) {
+    if (dtype_attr_ == kNumberTypeInt64) {
       slice_param_.begins_[i] = LongToInt(_begin[i]);
       slice_param_.ends_[i] = LongToInt(_end[i]);
       slice_param_.strides_[i] = LongToInt(_stride[i]);
@@ -596,66 +628,43 @@ void StridedSliceV2CpuKernelMod::ParallelRun(const uint8_t *input_addr, uint8_t
 }
 template <typename T>
-bool StridedSliceV2CpuKernelMod::StridedSliceV2LaunchDynamicType(const std::vector<kernel::AddressPtr> &inputs) {
-  auto cnode = cnode_ptr_.lock();
-  auto begin_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(cnode, 1);
-  auto end_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(cnode, 2);
-  auto stride_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(cnode, 3);
-  if (begin_shape.size() != 1 || end_shape.size() != 1 || stride_shape.size() != 1) {
+void StridedSliceV2CpuKernelMod::StridedSliceV2LaunchDynamicType(const std::vector<kernel::AddressPtr> &inputs) {
+  if (begin_shape_.size() != 1 || end_shape_.size() != 1 || stride_shape_.size() != 1) {
     MS_LOG(EXCEPTION) << "For '" << kernel_name_
                       << "', the dimension of 'begin', 'end', 'strides' should be equal "
                          "to 1, but got the dimension of 'begin': "
-                      << begin_shape.size() << ", the dimension of 'end': " << end_shape.size()
-                      << ", and the dimension of 'strides': " << stride_shape.size();
+                      << begin_shape_.size() << ", the dimension of 'end': " << end_shape_.size()
+                      << ", and the dimension of 'strides': " << stride_shape_.size();
   }
   auto begin_ptr = static_cast<T *>(inputs[1]->addr);
   auto end_ptr = static_cast<T *>(inputs[2]->addr);
   auto strides_ptr = static_cast<T *>(inputs[3]->addr);
-  std::vector<T> begin{begin_ptr, begin_ptr + begin_shape[0]};
-  std::vector<T> end{end_ptr, end_ptr + end_shape[0]};
-  std::vector<T> stride{strides_ptr, strides_ptr + stride_shape[0]};
+  std::vector<T> begin{begin_ptr, begin_ptr + begin_shape_[0]};
+  std::vector<T> end{end_ptr, end_ptr + end_shape_[0]};
+  std::vector<T> stride{strides_ptr, strides_ptr + stride_shape_[0]};
   slice_len = begin.size();
-  InitSliceParam<T>(cnode, &begin, &end, &stride);
-  return true;
+  InitSliceParam<T>(base_operator_, &begin, &end, &stride);
 }
-bool StridedSliceV2CpuKernelMod::StridedSliceV2LaunchCal(const std::vector<kernel::AddressPtr> &inputs,
+void StridedSliceV2CpuKernelMod::StridedSliceV2LaunchCal(const std::vector<kernel::AddressPtr> &inputs,
                                                          const std::vector<kernel::AddressPtr> &outputs) {
-  if (inputs.size() != kStridedSliceV2InputsNum && inputs.size() != kStridedSliceV2DynamicInputsNum) {
-    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the number of inputs should be " << kStridedSliceV2InputsNum
-                      << " or " << kStridedSliceV2DynamicInputsNum << ", but got " << inputs.size();
-  }
-  CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kStridedSliceV2OutputsNum, kernel_name_);
-  auto cnode = cnode_ptr_.lock();
-  size_t input_num = common::AnfAlgo::GetInputTensorNum(cnode);
-  if (input_num == kStridedSliceV2DynamicInputsNum) {
-    bool flag = false;
-    // for begin, end, stride are not const input
-    dtype_attr = AnfAlgo::GetInputDeviceDataType(cnode, 1);
-    if (dtype_attr == kNumberTypeInt32) {
-      flag = StridedSliceV2LaunchDynamicType<int32_t>(inputs);
-    } else {
-      flag = StridedSliceV2LaunchDynamicType<int64_t>(inputs);
-    }
-    if (!flag) {
-      MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the begin、end、stride is calculation error.";
-    }
-    parallel_ = MatchParallelPattern();
-    if (parallel_) {
-      InitParallelParam();
-    }
-  }
-  return true;
+  // for begin, end, stride are not const input
+  if (dtype_attr_ == kNumberTypeInt32) {
+    StridedSliceV2LaunchDynamicType<int32_t>(inputs);
+  } else {
+    StridedSliceV2LaunchDynamicType<int64_t>(inputs);
+  }
+  parallel_ = MatchParallelPattern();
+  if (parallel_) {
+    InitParallelParam();
+  }
 }
 bool StridedSliceV2CpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs,
                                         const std::vector<kernel::AddressPtr> & /* workspace */,
                                         const std::vector<kernel::AddressPtr> &outputs) {
-  bool ret = StridedSliceV2LaunchCal(inputs, outputs);
-  if (ret != true) {
-    MS_LOG(EXCEPTION) << "For StridedSliceV2 LaunchCal failed.";
+  if (inputs.size() == kStridedSliceV2DynamicInputsNum) {
+    StridedSliceV2LaunchCal(inputs, outputs);
   }
   auto input_addr = static_cast<uint8_t *>(inputs[0]->addr);
   auto output_addr = static_cast<uint8_t *>(outputs[0]->addr);
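
In the new StridedSliceV2LaunchDynamicType above, begin/end/strides arrive as 1-D tensors: their lengths come from the shapes cached in Resize (begin_shape_[0], ...), their values from raw launch-time addresses, and the element type (dtype_attr_) decides whether the int32 or int64 instantiation runs. A stand-alone sketch of that read pattern, with a hypothetical Address struct in place of kernel::AddressPtr:

// Hedged sketch of reading a 1-D index tensor at launch time.
// Address is a hypothetical stand-in for what kernel::AddressPtr points to.
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

struct Address {
  void *addr;
  size_t size;
};

template <typename T>
std::vector<T> ReadIndexTensor(const Address &input, const std::vector<int64_t> &cached_shape) {
  if (cached_shape.size() != 1) {  // same 1-D requirement the kernel enforces
    throw std::invalid_argument("begin/end/strides must be 1-D");
  }
  auto *ptr = static_cast<T *>(input.addr);
  return std::vector<T>(ptr, ptr + cached_shape[0]);
}

// Dispatch on the index dtype the way StridedSliceV2LaunchCal does:
//   dtype_attr_ == int32  ->  ReadIndexTensor<int32_t>(...)
//   otherwise             ->  ReadIndexTensor<int64_t>(...)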

View File

@@ -20,6 +20,7 @@
 #include <vector>
 #include <memory>
 #include <string>
+#include <map>
 #include "plugin/device/cpu/kernel/cpu_kernel.h"
 #include "plugin/factory/ms_factory.h"
 #include "nnacl/fp32/strided_slice_fp32.h"
@@ -27,13 +28,15 @@
 namespace mindspore {
 namespace kernel {
 constexpr auto kStridedSliceV2 = "StridedSliceV2";
-class StridedSliceV2CpuKernelMod : public DeprecatedNativeCpuKernelMod {
+class StridedSliceV2CpuKernelMod : public NativeCpuKernelMod {
  public:
   StridedSliceV2CpuKernelMod() = default;
   ~StridedSliceV2CpuKernelMod() override = default;
-  void InitKernel(const CNodePtr &kernel_node) override;
+  bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
+            const std::vector<KernelTensorPtr> &outputs) override;
+  int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
+             const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;
   bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
               const std::vector<AddressPtr> &outputs) override;
@@ -44,33 +47,38 @@ class StridedSliceV2CpuKernelMod : public DeprecatedNativeCpuKernelMod {
   enum ParallelStrategy { kOnSplitAxis, kOnOuter };
   template <typename T>
-  bool StridedSliceV2LaunchDynamicType(const std::vector<kernel::AddressPtr> &inputs);
+  void StridedSliceV2LaunchDynamicType(const std::vector<kernel::AddressPtr> &inputs);
   template <typename T>
-  void InitSliceParam(const CNodePtr &kernel_node, std::vector<T> *begin, std::vector<T> *end, std::vector<T> *stride);
+  void InitSliceParam(const BaseOperatorPtr &base_operator, std::vector<T> *begin, std::vector<T> *end,
+                      std::vector<T> *stride);
   bool MatchParallelPattern();
   void InitParallelParam();
   void ParallelRun(const uint8_t *input_addr, uint8_t *output_addr, int thread_num);
-  bool StridedSliceV2LaunchCal(const std::vector<kernel::AddressPtr> &inputs,
+  void StridedSliceV2LaunchCal(const std::vector<kernel::AddressPtr> &inputs,
                                const std::vector<kernel::AddressPtr> &outputs);
   common::Status RunTaskOnOuter(const uint8_t *input_addr, uint8_t *output_addr, int start_pos);
   common::Status RunTaskOnSplitAxis(const uint8_t *input_addr, uint8_t *output_addr, int start_pos);
-  void ParseMasks(const CNodePtr &kernel_node);
   TypeId dtype_;
-  TypeId dtype_attr;
+  TypeId dtype_attr_;
   int data_size_{4};
   int split_axis_{-1};
   int inner_{1};
   int outer_{1};
   int cal_num_per_thread_{1};
   bool parallel_{false};
+  BaseOperatorPtr base_operator_;
+  size_t inputs_num_;
   size_t shape_dim_input;
   size_t slice_len;
   ParallelStrategy parallel_strategy_{kOnSplitAxis};
   ShapeVector input_shape_;
   ShapeVector output_shape_;
+  ShapeVector begin_shape_;
+  ShapeVector end_shape_;
+  ShapeVector stride_shape_;
   StridedSliceParameter slice_param_;
 };
 } // namespace kernel
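
A design note on data_size_ and the uint8_t-based copy loops in the header above: InitSliceParam maps the tensor's TypeId to an element size (the type_convert_map in the .cc diff), so the slice itself can run over raw bytes for every supported dtype. A hedged, self-contained sketch of that idea, not the repository's table:

// Hedged sketch of a dtype -> element-size mapping used for byte-wise slicing.
#include <cstdint>
#include <unordered_map>

enum class DType { kBool, kInt8, kInt32, kFloat32, kFloat64 };  // illustrative subset

int ElementSize(DType t) {
  static const std::unordered_map<DType, int> kSizes = {
    {DType::kBool, sizeof(bool)},     {DType::kInt8, sizeof(int8_t)},
    {DType::kInt32, sizeof(int32_t)}, {DType::kFloat32, sizeof(float)},
    {DType::kFloat64, sizeof(double)},
  };
  return kSizes.at(t);
}
// Copying element i of a float32 tensor then touches bytes [4 * i, 4 * i + 4).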