forked from mindspore-Ecosystem/mindspore
ops StridedSliceV2, StridedSliceV2Grad supports dynamic shape feature
type: feature reason: add codes to support dynamic shape for StridedSliceV2, StridedSliceV2Grad. ------ Signed-off-by: wang_ziqi <wangziqi4@huawei.com>
This commit is contained in:
parent
33be16d103
commit
1ca07d19e9
|
@ -259,21 +259,24 @@ std::vector<bool> GradDec2Bin(const int64_t &mask) {
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
void ParseStrideSliceGradMasksST(const CNodePtr &kernel_node, std::vector<T> *begin, std::vector<T> *end,
|
||||
void ParseStrideSliceGradMasksST(const BaseOperatorPtr &base_operator, std::vector<T> *begin, std::vector<T> *end,
|
||||
std::vector<T> *stride, ShapeVector *input_shape, const ShapeVector output_shape,
|
||||
int shape_dim_output, int slice_len) {
|
||||
std::vector<T> &_begin_attr = *begin;
|
||||
std::vector<T> &_end_attr = *end;
|
||||
std::vector<T> &_stride_attr = *stride;
|
||||
auto begin_mask_int = common::AnfAlgo::GetNodeAttr<int64_t>(kernel_node, kAttrBeginMask);
|
||||
auto prim = base_operator->GetPrim();
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
|
||||
auto begin_mask_int = GetValue<int64_t>(prim->GetAttr(kAttrBeginMask));
|
||||
auto begin_mask = GradDec2Bin(begin_mask_int);
|
||||
auto end_mask_int = common::AnfAlgo::GetNodeAttr<int64_t>(kernel_node, kAttrEndMask);
|
||||
auto end_mask_int = GetValue<int64_t>(prim->GetAttr(kAttrEndMask));
|
||||
auto end_mask = GradDec2Bin(end_mask_int);
|
||||
auto ellipsis_mask_int = common::AnfAlgo::GetNodeAttr<int64_t>(kernel_node, kAttrEllipsisMask);
|
||||
auto ellipsis_mask_int = GetValue<int64_t>(prim->GetAttr(kAttrEllipsisMask));
|
||||
auto ellipsis_mask = GradDec2Bin(ellipsis_mask_int);
|
||||
auto new_axis_mask_int = common::AnfAlgo::GetNodeAttr<int64_t>(kernel_node, kAttrNewAxisMask);
|
||||
auto new_axis_mask_int = GetValue<int64_t>(prim->GetAttr(kAttrNewAxisMask));
|
||||
auto new_axis_mask = GradDec2Bin(new_axis_mask_int);
|
||||
auto shrink_axis_mask_int = common::AnfAlgo::GetNodeAttr<int64_t>(kernel_node, kAttrShrinkAxisMask);
|
||||
auto shrink_axis_mask_int = GetValue<int64_t>(prim->GetAttr(kAttrShrinkAxisMask));
|
||||
auto shrink_axis_mask = GradDec2Bin(shrink_axis_mask_int);
|
||||
int i = 0;
|
||||
int j = 0;
|
||||
|
@ -366,33 +369,54 @@ void FillEmptyDimsSTGrad(std::vector<T> *begin, std::vector<T> *end, std::vector
|
|||
}
|
||||
} // namespace
|
||||
|
||||
void StridedSliceV2GradCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
|
||||
cnode_ptr_ = kernel_node;
|
||||
ClearVectors();
|
||||
auto input_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, kIndex4);
|
||||
if (input_shape.size() > kStridedSliceV2GradMaxInputShapeSize) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dimension of input tensor must be 8D or lower, but got "
|
||||
<< input_shape.size() << "D.";
|
||||
}
|
||||
output_shape_ = common::AnfAlgo::GetOutputInferShape(kernel_node, 0);
|
||||
dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, kIndex4);
|
||||
dtype_grad_attr = AnfAlgo::GetInputDeviceDataType(kernel_node, 1);
|
||||
size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
if (input_num == kStridedSliceV2GradDynamicInputsNum) { // Dynamic Shape
|
||||
return;
|
||||
}
|
||||
// in the case that begin, end, size, stride are const value.
|
||||
std::vector<int64_t> begin_me = common::AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, BEGIN);
|
||||
(void)std::transform(begin_me.begin(), begin_me.end(), std::back_inserter(begin_),
|
||||
[](const int64_t &value) { return LongToInt(value); });
|
||||
auto prim = common::AnfAlgo::GetCNodePrimitive(kernel_node);
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
auto strides = prim->GetAttr(STRIDES);
|
||||
bool StridedSliceV2GradCpuKernelMod::Init(const BaseOperatorPtr &base_operator,
|
||||
const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) {
|
||||
MS_EXCEPTION_IF_NULL(base_operator);
|
||||
base_operator_ = base_operator;
|
||||
kernel_name_ = base_operator->name();
|
||||
|
||||
std::vector<int64_t> strides_me = common::AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, STRIDES);
|
||||
std::vector<int64_t> end_me = common::AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, END);
|
||||
if (inputs.empty()) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', input can not be empty.";
|
||||
}
|
||||
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOutputsNum, kernel_name_);
|
||||
|
||||
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
|
||||
auto is_match = MatchKernelAttr(kernel_attr, GetOpSupport());
|
||||
if (!is_match.first) {
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel data type: " << kernel_attr;
|
||||
return false;
|
||||
}
|
||||
|
||||
dtype_ = inputs[kIndex4]->GetDtype();
|
||||
dtype_grad_attr = inputs[kIndex1]->GetDtype();
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
int StridedSliceV2GradCpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
|
||||
const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs,
|
||||
const std::map<uint32_t, tensor::TensorPtr> &) {
|
||||
if (auto ret = KernelMod::Resize(base_operator, inputs, outputs); ret != KRET_OK) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
ClearVectors();
|
||||
begin_shape_ = inputs_[kIndex1]->GetShapeVector();
|
||||
end_shape_ = inputs_[kIndex2]->GetShapeVector();
|
||||
stride_shape_ = inputs_[kIndex3]->GetShapeVector();
|
||||
input_shape_ = inputs[kIndex4]->GetShapeVector();
|
||||
output_shape_ = outputs[kIndex0]->GetShapeVector();
|
||||
|
||||
if (inputs.size() == kStridedSliceV2GradDynamicInputsNum) { // Dynamic Shape
|
||||
return KRET_OK;
|
||||
}
|
||||
|
||||
auto prim = base_operator->GetPrim();
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
std::vector<int64_t> strides_me = GetValue<std::vector<int64_t>>(prim->GetAttr(STRIDES));
|
||||
std::vector<int64_t> end_me = GetValue<std::vector<int64_t>>(prim->GetAttr(END));
|
||||
(void)std::transform(strides_me.begin(), strides_me.end(), std::back_inserter(strides_),
|
||||
[](const int64_t &value) { return LongToInt(value); });
|
||||
(void)std::transform(end_me.begin(), end_me.end(), std::back_inserter(end_),
|
||||
|
@ -405,6 +429,7 @@ void StridedSliceV2GradCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
|
|||
}
|
||||
FormatArgs(true);
|
||||
ExpandAllMemberDims(kStridedSliceV2GradMaxInputShapeSize);
|
||||
return KRET_OK;
|
||||
}
|
||||
|
||||
void StridedSliceV2GradCpuKernelMod::ClearVectors() {
|
||||
|
@ -412,8 +437,6 @@ void StridedSliceV2GradCpuKernelMod::ClearVectors() {
|
|||
size_.clear();
|
||||
strides_.clear();
|
||||
end_.clear();
|
||||
input_element_num_.clear();
|
||||
output_element_num_.clear();
|
||||
input_shape_.clear();
|
||||
output_shape_.clear();
|
||||
}
|
||||
|
@ -449,33 +472,24 @@ void StridedSliceV2GradCpuKernelMod::ExpandAllMemberDims(size_t expand_dims) {
|
|||
// init for dynamic shape
|
||||
template <typename T>
|
||||
void StridedSliceV2GradCpuKernelMod::InitParams(const std::vector<kernel::AddressPtr> &inputs) {
|
||||
auto cnode = cnode_ptr_.lock();
|
||||
ClearVectors();
|
||||
output_shape_ = common::AnfAlgo::GetOutputInferShape(cnode, 0);
|
||||
std::string kernel_name = common::AnfAlgo::GetCNodeName(cnode);
|
||||
auto begin_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(cnode, 1);
|
||||
auto begin_ptr = static_cast<T *>(inputs[1]->addr);
|
||||
std::vector<T> begin{begin_ptr, begin_ptr + begin_shape[0]};
|
||||
|
||||
auto end_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(cnode, kIndex2);
|
||||
auto stride_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(cnode, kIndex3);
|
||||
if (begin_shape.size() != 1 || end_shape.size() != 1 || stride_shape.size() != 1) {
|
||||
if (begin_shape_.size() != 1 || end_shape_.size() != 1 || stride_shape_.size() != 1) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_
|
||||
<< "', the dimensions of 'begin', 'end', 'strides' must be 1, "
|
||||
"but got the dimension of 'begin': "
|
||||
<< begin_shape.size() << ", the dimension of 'end': " << end_shape.size()
|
||||
<< ", and the dimension of 'strides': " << stride_shape.size();
|
||||
<< begin_shape_.size() << ", the dimension of 'end': " << end_shape_.size()
|
||||
<< ", and the dimension of 'strides': " << stride_shape_.size();
|
||||
}
|
||||
auto begin_ptr = static_cast<T *>(inputs[1]->addr);
|
||||
std::vector<T> begin{begin_ptr, begin_ptr + begin_shape_[0]};
|
||||
auto end_ptr = static_cast<T *>(inputs[kIndex2]->addr);
|
||||
auto strides_ptr = static_cast<T *>(inputs[kIndex3]->addr);
|
||||
|
||||
std::vector<T> end{end_ptr, end_ptr + end_shape[0]};
|
||||
std::vector<T> strides{strides_ptr, strides_ptr + stride_shape[0]};
|
||||
input_shape_ = common::AnfAlgo::GetPrevNodeOutputInferShape(cnode, kIndex4);
|
||||
std::vector<T> end{end_ptr, end_ptr + end_shape_[0]};
|
||||
std::vector<T> strides{strides_ptr, strides_ptr + stride_shape_[0]};
|
||||
shape_dim_output = SizeToInt(output_shape_.size());
|
||||
slice_len = SizeToInt(begin.size());
|
||||
FillEmptyDimsSTGrad<T>(&begin, &end, &strides, &input_shape_, &output_shape_);
|
||||
ParseStrideSliceGradMasksST<T>(cnode, &begin, &end, &strides, &input_shape_, output_shape_, shape_dim_output,
|
||||
ParseStrideSliceGradMasksST<T>(base_operator_, &begin, &end, &strides, &input_shape_, output_shape_, shape_dim_output,
|
||||
slice_len);
|
||||
FillEmptyDimsSTGrad<T>(&begin, &end, &strides, &input_shape_, &output_shape_);
|
||||
(void)std::transform(begin.begin(), begin.end(), std::back_inserter(begin_), [](const T &value) { return value; });
|
||||
|
@ -494,10 +508,6 @@ void StridedSliceV2GradCpuKernelMod::InitParams(const std::vector<kernel::Addres
|
|||
bool StridedSliceV2GradCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> &,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
if (inputs.empty()) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', input can not be empty.";
|
||||
}
|
||||
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOutputsNum, kernel_name_);
|
||||
bool ret = true;
|
||||
if (dtype_ == kNumberTypeInt32) {
|
||||
ret = LaunchKernel<int32_t>(inputs, outputs);
|
||||
|
|
|
@ -20,7 +20,7 @@
|
|||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include <map>
|
||||
#include "plugin/device/cpu/kernel/nnacl/fp32_grad/strided_slice_grad.h"
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel.h"
|
||||
#include "plugin/factory/ms_factory.h"
|
||||
|
@ -30,7 +30,7 @@ namespace kernel {
|
|||
constexpr auto kStridedSliceV2Grad = "StridedSliceV2Grad";
|
||||
constexpr auto kUnknown = "Unknown";
|
||||
|
||||
class StridedSliceV2GradCpuKernelMod : public DeprecatedNativeCpuKernelMod {
|
||||
class StridedSliceV2GradCpuKernelMod : public NativeCpuKernelMod {
|
||||
public:
|
||||
StridedSliceV2GradCpuKernelMod() = default;
|
||||
|
||||
|
@ -38,8 +38,10 @@ class StridedSliceV2GradCpuKernelMod : public DeprecatedNativeCpuKernelMod {
|
|||
|
||||
~StridedSliceV2GradCpuKernelMod() override = default;
|
||||
|
||||
void InitKernel(const CNodePtr &kernel_node) override;
|
||||
|
||||
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) override;
|
||||
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
|
||||
|
@ -68,16 +70,18 @@ class StridedSliceV2GradCpuKernelMod : public DeprecatedNativeCpuKernelMod {
|
|||
template <typename T>
|
||||
bool CalStridedSliceV2Grad(T *input, T *output);
|
||||
|
||||
BaseOperatorPtr base_operator_;
|
||||
std::vector<int> begin_;
|
||||
std::vector<int> end_;
|
||||
std::vector<int> strides_;
|
||||
std::vector<int> size_;
|
||||
ShapeVector input_shape_;
|
||||
int shape_dim_output{0};
|
||||
int slice_len{0};
|
||||
std::vector<size_t> input_element_num_;
|
||||
ShapeVector input_shape_;
|
||||
ShapeVector begin_shape_;
|
||||
ShapeVector end_shape_;
|
||||
ShapeVector stride_shape_;
|
||||
ShapeVector output_shape_;
|
||||
std::vector<size_t> output_element_num_;
|
||||
TypeId dtype_{kTypeUnknown};
|
||||
TypeId dtype_grad_attr{kTypeUnknown};
|
||||
std::string kernel_type_{kUnknown};
|
||||
|
|
|
@ -311,22 +311,22 @@ static std::map<std::string, std::vector<KernelAttr>> support_list_map = {
|
|||
.AddOutputAttr(kNumberTypeComplex128)}}};
|
||||
|
||||
template <typename T>
|
||||
void ParseStrideSliceMasksST(const CNodePtr &kernel_node, std::vector<T> *begin, std::vector<T> *end,
|
||||
void ParseStrideSliceMasksST(const BaseOperatorPtr &base_operator, std::vector<T> *begin, std::vector<T> *end,
|
||||
std::vector<T> *stride, const ShapeVector &input_shape, size_t shape_dim_input,
|
||||
size_t slice_len) {
|
||||
std::vector<T> &_begin_attr = *begin;
|
||||
std::vector<T> &_end_attr = *end;
|
||||
std::vector<T> &_stride_attr = *stride;
|
||||
|
||||
auto begin_mask_int = common::AnfAlgo::GetNodeAttr<int64_t>(kernel_node, kAttrBeginMask);
|
||||
auto prim = base_operator->GetPrim();
|
||||
auto begin_mask_int = GetValue<int64_t>(prim->GetAttr(kAttrBeginMask));
|
||||
auto begin_mask = Dec2Bin(begin_mask_int);
|
||||
auto end_mask_int = common::AnfAlgo::GetNodeAttr<int64_t>(kernel_node, kAttrEndMask);
|
||||
auto end_mask_int = GetValue<int64_t>(prim->GetAttr(kAttrEndMask));
|
||||
auto end_mask = Dec2Bin(end_mask_int);
|
||||
auto ellipsis_mask_int = common::AnfAlgo::GetNodeAttr<int64_t>(kernel_node, kAttrEllipsisMask);
|
||||
auto ellipsis_mask_int = GetValue<int64_t>(prim->GetAttr(kAttrEllipsisMask));
|
||||
auto ellipsis_mask = Dec2Bin(ellipsis_mask_int);
|
||||
auto new_axis_mask_int = common::AnfAlgo::GetNodeAttr<int64_t>(kernel_node, kAttrNewAxisMask);
|
||||
auto new_axis_mask_int = GetValue<int64_t>(prim->GetAttr(kAttrNewAxisMask));
|
||||
auto new_axis_mask = Dec2Bin(new_axis_mask_int);
|
||||
auto shrink_axis_mask_int = common::AnfAlgo::GetNodeAttr<int64_t>(kernel_node, kAttrShrinkAxisMask);
|
||||
auto shrink_axis_mask_int = GetValue<int64_t>(prim->GetAttr(kAttrShrinkAxisMask));
|
||||
auto shrink_axis_mask = Dec2Bin(shrink_axis_mask_int);
|
||||
size_t i = 0;
|
||||
size_t j = 0;
|
||||
|
@ -380,14 +380,14 @@ void ParseStrideSliceMasksST(const CNodePtr &kernel_node, std::vector<T> *begin,
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
void FillEmptyDimsST(const CNodePtr &kernel_node, std::vector<T> *begin, std::vector<T> *end, std::vector<T> *stride,
|
||||
ShapeVector *input_shape) {
|
||||
void FillEmptyDimsST(const BaseOperatorPtr &base_operator, std::vector<T> *begin, std::vector<T> *end,
|
||||
std::vector<T> *stride, ShapeVector *input_shape) {
|
||||
std::vector<T> &_begin = *begin;
|
||||
std::vector<T> &_end = *end;
|
||||
std::vector<T> &_stride = *stride;
|
||||
auto &_input_shape = *input_shape;
|
||||
if (_begin.size() != _end.size() || _begin.size() != _stride.size() || _begin.size() > _input_shape.size()) {
|
||||
MS_LOG(EXCEPTION) << "For '" << common::AnfAlgo::GetCNodeName(kernel_node)
|
||||
MS_LOG(EXCEPTION) << "For '" << base_operator->name()
|
||||
<< "', the length of 'begin', 'stride' and 'end' should be equal "
|
||||
"and less than or equal to the dimension of 'input_x', but got the length of 'begin': "
|
||||
<< _begin.size() << ", the length of 'stride': " << _stride.size()
|
||||
|
@ -416,27 +416,59 @@ void FillEmptyDimsST(const CNodePtr &kernel_node, std::vector<T> *begin, std::ve
|
|||
}
|
||||
} // namespace
|
||||
|
||||
void StridedSliceV2CpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
|
||||
cnode_ptr_ = kernel_node;
|
||||
input_shape_ = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
output_shape_ = common::AnfAlgo::GetOutputInferShape(kernel_node, 0);
|
||||
dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
|
||||
size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
if (input_num != 1) {
|
||||
return;
|
||||
bool StridedSliceV2CpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) {
|
||||
MS_EXCEPTION_IF_NULL(base_operator);
|
||||
base_operator_ = base_operator;
|
||||
kernel_name_ = base_operator->name();
|
||||
|
||||
if (inputs.size() != kStridedSliceV2InputsNum && inputs.size() != kStridedSliceV2DynamicInputsNum) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the number of inputs should be " << kStridedSliceV2InputsNum
|
||||
<< " or " << kStridedSliceV2DynamicInputsNum << ", but got " << inputs.size();
|
||||
}
|
||||
// for begin, end, stride are const input
|
||||
auto begin = common::AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, kAttrBegin);
|
||||
auto end = common::AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, kAttrEnd);
|
||||
auto stride = common::AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, kAttrStrides);
|
||||
InitSliceParam<int64_t>(kernel_node, &begin, &end, &stride);
|
||||
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kStridedSliceV2OutputsNum, kernel_name_);
|
||||
|
||||
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
|
||||
auto is_match = MatchKernelAttr(kernel_attr, GetOpSupport());
|
||||
if (!is_match.first) {
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel data type: " << kernel_attr;
|
||||
return false;
|
||||
}
|
||||
|
||||
dtype_ = inputs[kIndex0]->GetDtype();
|
||||
dtype_attr_ = inputs[kIndex1]->GetDtype();
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
int StridedSliceV2CpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs,
|
||||
const std::map<uint32_t, tensor::TensorPtr> &) {
|
||||
if (auto ret = KernelMod::Resize(base_operator, inputs, outputs); ret != KRET_OK) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
input_shape_ = inputs[kIndex0]->GetShapeVector();
|
||||
begin_shape_ = inputs[kIndex1]->GetShapeVector();
|
||||
end_shape_ = inputs[kIndex2]->GetShapeVector();
|
||||
stride_shape_ = inputs[kIndex3]->GetShapeVector();
|
||||
output_shape_ = outputs[kIndex0]->GetShapeVector();
|
||||
|
||||
if (inputs.size() == kStridedSliceV2DynamicInputsNum) {
|
||||
return KRET_OK;
|
||||
}
|
||||
|
||||
auto prim = base_operator->GetPrim();
|
||||
auto begin = GetValue<std::vector<int64_t>>(prim->GetAttr(kAttrBegin));
|
||||
auto end = GetValue<std::vector<int64_t>>(prim->GetAttr(kAttrEnd));
|
||||
auto stride = GetValue<std::vector<int64_t>>(prim->GetAttr(kAttrStrides));
|
||||
InitSliceParam<int64_t>(base_operator, &begin, &end, &stride);
|
||||
|
||||
parallel_ = MatchParallelPattern();
|
||||
if (parallel_) {
|
||||
InitParallelParam();
|
||||
}
|
||||
return KRET_OK;
|
||||
}
|
||||
|
||||
bool StridedSliceV2CpuKernelMod::MatchParallelPattern() {
|
||||
|
@ -492,8 +524,8 @@ void StridedSliceV2CpuKernelMod::InitParallelParam() {
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
void StridedSliceV2CpuKernelMod::InitSliceParam(const CNodePtr &kernel_node, std::vector<T> *begin, std::vector<T> *end,
|
||||
std::vector<T> *stride) {
|
||||
void StridedSliceV2CpuKernelMod::InitSliceParam(const BaseOperatorPtr &base_operator, std::vector<T> *begin,
|
||||
std::vector<T> *end, std::vector<T> *stride) {
|
||||
static const std::unordered_map<TypeId, std::pair<TypeIdC, int>> type_convert_map = {
|
||||
{kNumberTypeBool, {::kNumberTypeBool, sizeof(bool)}},
|
||||
{kNumberTypeInt8, {::kNumberTypeInt8, sizeof(int8_t)}},
|
||||
|
@ -521,16 +553,16 @@ void StridedSliceV2CpuKernelMod::InitSliceParam(const CNodePtr &kernel_node, std
|
|||
slice_param_.data_type = type_pair->second.first;
|
||||
auto input_shape_pad = input_shape_;
|
||||
shape_dim_input = input_shape_.size();
|
||||
FillEmptyDimsST<T>(kernel_node, begin, end, stride, &input_shape_pad);
|
||||
ParseStrideSliceMasksST<T>(kernel_node, begin, end, stride, input_shape_, shape_dim_input, slice_len);
|
||||
FillEmptyDimsST<T>(kernel_node, begin, end, stride, &input_shape_pad);
|
||||
FillEmptyDimsST<T>(base_operator, begin, end, stride, &input_shape_pad);
|
||||
ParseStrideSliceMasksST<T>(base_operator, begin, end, stride, input_shape_, shape_dim_input, slice_len);
|
||||
FillEmptyDimsST<T>(base_operator, begin, end, stride, &input_shape_pad);
|
||||
|
||||
std::vector<T> &_begin = *begin;
|
||||
std::vector<T> &_end = *end;
|
||||
std::vector<T> &_stride = *stride;
|
||||
for (size_t i = 0; i < DIMENSION_8D; i++) {
|
||||
slice_param_.in_shape_[i] = SizeToInt(input_shape_pad[i]);
|
||||
if (dtype_attr == kNumberTypeInt64) {
|
||||
if (dtype_attr_ == kNumberTypeInt64) {
|
||||
slice_param_.begins_[i] = LongToInt(_begin[i]);
|
||||
slice_param_.ends_[i] = LongToInt(_end[i]);
|
||||
slice_param_.strides_[i] = LongToInt(_stride[i]);
|
||||
|
@ -596,66 +628,43 @@ void StridedSliceV2CpuKernelMod::ParallelRun(const uint8_t *input_addr, uint8_t
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
bool StridedSliceV2CpuKernelMod::StridedSliceV2LaunchDynamicType(const std::vector<kernel::AddressPtr> &inputs) {
|
||||
auto cnode = cnode_ptr_.lock();
|
||||
auto begin_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(cnode, 1);
|
||||
auto end_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(cnode, 2);
|
||||
auto stride_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(cnode, 3);
|
||||
if (begin_shape.size() != 1 || end_shape.size() != 1 || stride_shape.size() != 1) {
|
||||
void StridedSliceV2CpuKernelMod::StridedSliceV2LaunchDynamicType(const std::vector<kernel::AddressPtr> &inputs) {
|
||||
if (begin_shape_.size() != 1 || end_shape_.size() != 1 || stride_shape_.size() != 1) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_
|
||||
<< "', the dimension of 'begin', 'end', 'strides' should be equal "
|
||||
"to 1, but got the dimension of 'begin': "
|
||||
<< begin_shape.size() << ", the dimension of 'end': " << end_shape.size()
|
||||
<< ", and the dimension of 'strides': " << stride_shape.size();
|
||||
<< begin_shape_.size() << ", the dimension of 'end': " << end_shape_.size()
|
||||
<< ", and the dimension of 'strides': " << stride_shape_.size();
|
||||
}
|
||||
auto begin_ptr = static_cast<T *>(inputs[1]->addr);
|
||||
auto end_ptr = static_cast<T *>(inputs[2]->addr);
|
||||
auto strides_ptr = static_cast<T *>(inputs[3]->addr);
|
||||
std::vector<T> begin{begin_ptr, begin_ptr + begin_shape[0]};
|
||||
std::vector<T> end{end_ptr, end_ptr + end_shape[0]};
|
||||
std::vector<T> stride{strides_ptr, strides_ptr + stride_shape[0]};
|
||||
std::vector<T> begin{begin_ptr, begin_ptr + begin_shape_[0]};
|
||||
std::vector<T> end{end_ptr, end_ptr + end_shape_[0]};
|
||||
std::vector<T> stride{strides_ptr, strides_ptr + stride_shape_[0]};
|
||||
slice_len = begin.size();
|
||||
InitSliceParam<T>(cnode, &begin, &end, &stride);
|
||||
return true;
|
||||
InitSliceParam<T>(base_operator_, &begin, &end, &stride);
|
||||
}
|
||||
|
||||
bool StridedSliceV2CpuKernelMod::StridedSliceV2LaunchCal(const std::vector<kernel::AddressPtr> &inputs,
|
||||
void StridedSliceV2CpuKernelMod::StridedSliceV2LaunchCal(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
if (inputs.size() != kStridedSliceV2InputsNum && inputs.size() != kStridedSliceV2DynamicInputsNum) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the number of inputs should be " << kStridedSliceV2InputsNum
|
||||
<< " or " << kStridedSliceV2DynamicInputsNum << ", but got " << inputs.size();
|
||||
// for begin, end, stride are not const input
|
||||
if (dtype_attr_ == kNumberTypeInt32) {
|
||||
StridedSliceV2LaunchDynamicType<int32_t>(inputs);
|
||||
} else {
|
||||
StridedSliceV2LaunchDynamicType<int64_t>(inputs);
|
||||
}
|
||||
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kStridedSliceV2OutputsNum, kernel_name_);
|
||||
|
||||
auto cnode = cnode_ptr_.lock();
|
||||
size_t input_num = common::AnfAlgo::GetInputTensorNum(cnode);
|
||||
if (input_num == kStridedSliceV2DynamicInputsNum) {
|
||||
bool flag = false;
|
||||
// for begin, end, stride are not const input
|
||||
dtype_attr = AnfAlgo::GetInputDeviceDataType(cnode, 1);
|
||||
if (dtype_attr == kNumberTypeInt32) {
|
||||
flag = StridedSliceV2LaunchDynamicType<int32_t>(inputs);
|
||||
} else {
|
||||
flag = StridedSliceV2LaunchDynamicType<int64_t>(inputs);
|
||||
}
|
||||
if (!flag) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the begin、end、stride is calculation error.";
|
||||
}
|
||||
|
||||
parallel_ = MatchParallelPattern();
|
||||
if (parallel_) {
|
||||
InitParallelParam();
|
||||
}
|
||||
parallel_ = MatchParallelPattern();
|
||||
if (parallel_) {
|
||||
InitParallelParam();
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool StridedSliceV2CpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> & /* workspace */,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
bool ret = StridedSliceV2LaunchCal(inputs, outputs);
|
||||
if (ret != true) {
|
||||
MS_LOG(EXCEPTION) << "For StridedSliceV2 LaunchCal failed.";
|
||||
if (inputs.size() == kStridedSliceV2DynamicInputsNum) {
|
||||
StridedSliceV2LaunchCal(inputs, outputs);
|
||||
}
|
||||
auto input_addr = static_cast<uint8_t *>(inputs[0]->addr);
|
||||
auto output_addr = static_cast<uint8_t *>(outputs[0]->addr);
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
#include <vector>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <map>
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel.h"
|
||||
#include "plugin/factory/ms_factory.h"
|
||||
#include "nnacl/fp32/strided_slice_fp32.h"
|
||||
|
@ -27,13 +28,15 @@
|
|||
namespace mindspore {
|
||||
namespace kernel {
|
||||
constexpr auto kStridedSliceV2 = "StridedSliceV2";
|
||||
class StridedSliceV2CpuKernelMod : public DeprecatedNativeCpuKernelMod {
|
||||
class StridedSliceV2CpuKernelMod : public NativeCpuKernelMod {
|
||||
public:
|
||||
StridedSliceV2CpuKernelMod() = default;
|
||||
~StridedSliceV2CpuKernelMod() override = default;
|
||||
|
||||
void InitKernel(const CNodePtr &kernel_node) override;
|
||||
|
||||
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) override;
|
||||
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
|
||||
|
@ -44,33 +47,38 @@ class StridedSliceV2CpuKernelMod : public DeprecatedNativeCpuKernelMod {
|
|||
enum ParallelStrategy { kOnSplitAxis, kOnOuter };
|
||||
|
||||
template <typename T>
|
||||
bool StridedSliceV2LaunchDynamicType(const std::vector<kernel::AddressPtr> &inputs);
|
||||
void StridedSliceV2LaunchDynamicType(const std::vector<kernel::AddressPtr> &inputs);
|
||||
|
||||
template <typename T>
|
||||
void InitSliceParam(const CNodePtr &kernel_node, std::vector<T> *begin, std::vector<T> *end, std::vector<T> *stride);
|
||||
void InitSliceParam(const BaseOperatorPtr &base_operator, std::vector<T> *begin, std::vector<T> *end,
|
||||
std::vector<T> *stride);
|
||||
bool MatchParallelPattern();
|
||||
void InitParallelParam();
|
||||
void ParallelRun(const uint8_t *input_addr, uint8_t *output_addr, int thread_num);
|
||||
|
||||
bool StridedSliceV2LaunchCal(const std::vector<kernel::AddressPtr> &inputs,
|
||||
void StridedSliceV2LaunchCal(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> &outputs);
|
||||
common::Status RunTaskOnOuter(const uint8_t *input_addr, uint8_t *output_addr, int start_pos);
|
||||
common::Status RunTaskOnSplitAxis(const uint8_t *input_addr, uint8_t *output_addr, int start_pos);
|
||||
void ParseMasks(const CNodePtr &kernel_node);
|
||||
|
||||
TypeId dtype_;
|
||||
TypeId dtype_attr;
|
||||
TypeId dtype_attr_;
|
||||
int data_size_{4};
|
||||
int split_axis_{-1};
|
||||
int inner_{1};
|
||||
int outer_{1};
|
||||
int cal_num_per_thread_{1};
|
||||
bool parallel_{false};
|
||||
BaseOperatorPtr base_operator_;
|
||||
size_t inputs_num_;
|
||||
size_t shape_dim_input;
|
||||
size_t slice_len;
|
||||
ParallelStrategy parallel_strategy_{kOnSplitAxis};
|
||||
ShapeVector input_shape_;
|
||||
ShapeVector output_shape_;
|
||||
ShapeVector begin_shape_;
|
||||
ShapeVector end_shape_;
|
||||
ShapeVector stride_shape_;
|
||||
StridedSliceParameter slice_param_;
|
||||
};
|
||||
} // namespace kernel
|
||||
|
|
Loading…
Reference in New Issue