diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/strided_slice_v2_grad_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/strided_slice_v2_grad_cpu_kernel.cc index 965026cb6a0..51c06501abb 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/strided_slice_v2_grad_cpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/strided_slice_v2_grad_cpu_kernel.cc @@ -259,21 +259,24 @@ std::vector GradDec2Bin(const int64_t &mask) { } template -void ParseStrideSliceGradMasksST(const CNodePtr &kernel_node, std::vector *begin, std::vector *end, +void ParseStrideSliceGradMasksST(const BaseOperatorPtr &base_operator, std::vector *begin, std::vector *end, std::vector *stride, ShapeVector *input_shape, const ShapeVector output_shape, int shape_dim_output, int slice_len) { std::vector &_begin_attr = *begin; std::vector &_end_attr = *end; std::vector &_stride_attr = *stride; - auto begin_mask_int = common::AnfAlgo::GetNodeAttr(kernel_node, kAttrBeginMask); + auto prim = base_operator->GetPrim(); + MS_EXCEPTION_IF_NULL(prim); + + auto begin_mask_int = GetValue(prim->GetAttr(kAttrBeginMask)); auto begin_mask = GradDec2Bin(begin_mask_int); - auto end_mask_int = common::AnfAlgo::GetNodeAttr(kernel_node, kAttrEndMask); + auto end_mask_int = GetValue(prim->GetAttr(kAttrEndMask)); auto end_mask = GradDec2Bin(end_mask_int); - auto ellipsis_mask_int = common::AnfAlgo::GetNodeAttr(kernel_node, kAttrEllipsisMask); + auto ellipsis_mask_int = GetValue(prim->GetAttr(kAttrEllipsisMask)); auto ellipsis_mask = GradDec2Bin(ellipsis_mask_int); - auto new_axis_mask_int = common::AnfAlgo::GetNodeAttr(kernel_node, kAttrNewAxisMask); + auto new_axis_mask_int = GetValue(prim->GetAttr(kAttrNewAxisMask)); auto new_axis_mask = GradDec2Bin(new_axis_mask_int); - auto shrink_axis_mask_int = common::AnfAlgo::GetNodeAttr(kernel_node, kAttrShrinkAxisMask); + auto shrink_axis_mask_int = GetValue(prim->GetAttr(kAttrShrinkAxisMask)); auto shrink_axis_mask = GradDec2Bin(shrink_axis_mask_int); int i = 0; int j = 0; @@ -366,33 +369,54 @@ void FillEmptyDimsSTGrad(std::vector *begin, std::vector *end, std::vector } } // namespace -void StridedSliceV2GradCpuKernelMod::InitKernel(const CNodePtr &kernel_node) { - MS_EXCEPTION_IF_NULL(kernel_node); - kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node); - cnode_ptr_ = kernel_node; - ClearVectors(); - auto input_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, kIndex4); - if (input_shape.size() > kStridedSliceV2GradMaxInputShapeSize) { - MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dimension of input tensor must be 8D or lower, but got " - << input_shape.size() << "D."; - } - output_shape_ = common::AnfAlgo::GetOutputInferShape(kernel_node, 0); - dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, kIndex4); - dtype_grad_attr = AnfAlgo::GetInputDeviceDataType(kernel_node, 1); - size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node); - if (input_num == kStridedSliceV2GradDynamicInputsNum) { // Dynamic Shape - return; - } - // in the case that begin, end, size, stride are const value. - std::vector begin_me = common::AnfAlgo::GetNodeAttr>(kernel_node, BEGIN); - (void)std::transform(begin_me.begin(), begin_me.end(), std::back_inserter(begin_), - [](const int64_t &value) { return LongToInt(value); }); - auto prim = common::AnfAlgo::GetCNodePrimitive(kernel_node); - MS_EXCEPTION_IF_NULL(prim); - auto strides = prim->GetAttr(STRIDES); +bool StridedSliceV2GradCpuKernelMod::Init(const BaseOperatorPtr &base_operator, + const std::vector &inputs, + const std::vector &outputs) { + MS_EXCEPTION_IF_NULL(base_operator); + base_operator_ = base_operator; + kernel_name_ = base_operator->name(); - std::vector strides_me = common::AnfAlgo::GetNodeAttr>(kernel_node, STRIDES); - std::vector end_me = common::AnfAlgo::GetNodeAttr>(kernel_node, END); + if (inputs.empty()) { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', input can not be empty."; + } + CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOutputsNum, kernel_name_); + + auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); + auto is_match = MatchKernelAttr(kernel_attr, GetOpSupport()); + if (!is_match.first) { + MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel data type: " << kernel_attr; + return false; + } + + dtype_ = inputs[kIndex4]->GetDtype(); + dtype_grad_attr = inputs[kIndex1]->GetDtype(); + + return true; +} + +int StridedSliceV2GradCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, + const std::vector &inputs, + const std::vector &outputs, + const std::map &) { + if (auto ret = KernelMod::Resize(base_operator, inputs, outputs); ret != KRET_OK) { + return ret; + } + + ClearVectors(); + begin_shape_ = inputs_[kIndex1]->GetShapeVector(); + end_shape_ = inputs_[kIndex2]->GetShapeVector(); + stride_shape_ = inputs_[kIndex3]->GetShapeVector(); + input_shape_ = inputs[kIndex4]->GetShapeVector(); + output_shape_ = outputs[kIndex0]->GetShapeVector(); + + if (inputs.size() == kStridedSliceV2GradDynamicInputsNum) { // Dynamic Shape + return KRET_OK; + } + + auto prim = base_operator->GetPrim(); + MS_EXCEPTION_IF_NULL(prim); + std::vector strides_me = GetValue>(prim->GetAttr(STRIDES)); + std::vector end_me = GetValue>(prim->GetAttr(END)); (void)std::transform(strides_me.begin(), strides_me.end(), std::back_inserter(strides_), [](const int64_t &value) { return LongToInt(value); }); (void)std::transform(end_me.begin(), end_me.end(), std::back_inserter(end_), @@ -405,6 +429,7 @@ void StridedSliceV2GradCpuKernelMod::InitKernel(const CNodePtr &kernel_node) { } FormatArgs(true); ExpandAllMemberDims(kStridedSliceV2GradMaxInputShapeSize); + return KRET_OK; } void StridedSliceV2GradCpuKernelMod::ClearVectors() { @@ -412,8 +437,6 @@ void StridedSliceV2GradCpuKernelMod::ClearVectors() { size_.clear(); strides_.clear(); end_.clear(); - input_element_num_.clear(); - output_element_num_.clear(); input_shape_.clear(); output_shape_.clear(); } @@ -449,33 +472,24 @@ void StridedSliceV2GradCpuKernelMod::ExpandAllMemberDims(size_t expand_dims) { // init for dynamic shape template void StridedSliceV2GradCpuKernelMod::InitParams(const std::vector &inputs) { - auto cnode = cnode_ptr_.lock(); - ClearVectors(); - output_shape_ = common::AnfAlgo::GetOutputInferShape(cnode, 0); - std::string kernel_name = common::AnfAlgo::GetCNodeName(cnode); - auto begin_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(cnode, 1); - auto begin_ptr = static_cast(inputs[1]->addr); - std::vector begin{begin_ptr, begin_ptr + begin_shape[0]}; - - auto end_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(cnode, kIndex2); - auto stride_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(cnode, kIndex3); - if (begin_shape.size() != 1 || end_shape.size() != 1 || stride_shape.size() != 1) { + if (begin_shape_.size() != 1 || end_shape_.size() != 1 || stride_shape_.size() != 1) { MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dimensions of 'begin', 'end', 'strides' must be 1, " "but got the dimension of 'begin': " - << begin_shape.size() << ", the dimension of 'end': " << end_shape.size() - << ", and the dimension of 'strides': " << stride_shape.size(); + << begin_shape_.size() << ", the dimension of 'end': " << end_shape_.size() + << ", and the dimension of 'strides': " << stride_shape_.size(); } + auto begin_ptr = static_cast(inputs[1]->addr); + std::vector begin{begin_ptr, begin_ptr + begin_shape_[0]}; auto end_ptr = static_cast(inputs[kIndex2]->addr); auto strides_ptr = static_cast(inputs[kIndex3]->addr); - std::vector end{end_ptr, end_ptr + end_shape[0]}; - std::vector strides{strides_ptr, strides_ptr + stride_shape[0]}; - input_shape_ = common::AnfAlgo::GetPrevNodeOutputInferShape(cnode, kIndex4); + std::vector end{end_ptr, end_ptr + end_shape_[0]}; + std::vector strides{strides_ptr, strides_ptr + stride_shape_[0]}; shape_dim_output = SizeToInt(output_shape_.size()); slice_len = SizeToInt(begin.size()); FillEmptyDimsSTGrad(&begin, &end, &strides, &input_shape_, &output_shape_); - ParseStrideSliceGradMasksST(cnode, &begin, &end, &strides, &input_shape_, output_shape_, shape_dim_output, + ParseStrideSliceGradMasksST(base_operator_, &begin, &end, &strides, &input_shape_, output_shape_, shape_dim_output, slice_len); FillEmptyDimsSTGrad(&begin, &end, &strides, &input_shape_, &output_shape_); (void)std::transform(begin.begin(), begin.end(), std::back_inserter(begin_), [](const T &value) { return value; }); @@ -494,10 +508,6 @@ void StridedSliceV2GradCpuKernelMod::InitParams(const std::vector &inputs, const std::vector &, const std::vector &outputs) { - if (inputs.empty()) { - MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', input can not be empty."; - } - CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOutputsNum, kernel_name_); bool ret = true; if (dtype_ == kNumberTypeInt32) { ret = LaunchKernel(inputs, outputs); diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/strided_slice_v2_grad_cpu_kernel.h b/mindspore/ccsrc/plugin/device/cpu/kernel/strided_slice_v2_grad_cpu_kernel.h index 98ebb73a807..97ad1d1c117 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/strided_slice_v2_grad_cpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/strided_slice_v2_grad_cpu_kernel.h @@ -20,7 +20,7 @@ #include #include #include - +#include #include "plugin/device/cpu/kernel/nnacl/fp32_grad/strided_slice_grad.h" #include "plugin/device/cpu/kernel/cpu_kernel.h" #include "plugin/factory/ms_factory.h" @@ -30,7 +30,7 @@ namespace kernel { constexpr auto kStridedSliceV2Grad = "StridedSliceV2Grad"; constexpr auto kUnknown = "Unknown"; -class StridedSliceV2GradCpuKernelMod : public DeprecatedNativeCpuKernelMod { +class StridedSliceV2GradCpuKernelMod : public NativeCpuKernelMod { public: StridedSliceV2GradCpuKernelMod() = default; @@ -38,8 +38,10 @@ class StridedSliceV2GradCpuKernelMod : public DeprecatedNativeCpuKernelMod { ~StridedSliceV2GradCpuKernelMod() override = default; - void InitKernel(const CNodePtr &kernel_node) override; - + bool Init(const BaseOperatorPtr &base_operator, const std::vector &inputs, + const std::vector &outputs) override; + int Resize(const BaseOperatorPtr &base_operator, const std::vector &inputs, + const std::vector &outputs, const std::map &) override; bool Launch(const std::vector &inputs, const std::vector &workspace, const std::vector &outputs) override; @@ -68,16 +70,18 @@ class StridedSliceV2GradCpuKernelMod : public DeprecatedNativeCpuKernelMod { template bool CalStridedSliceV2Grad(T *input, T *output); + BaseOperatorPtr base_operator_; std::vector begin_; std::vector end_; std::vector strides_; std::vector size_; - ShapeVector input_shape_; int shape_dim_output{0}; int slice_len{0}; - std::vector input_element_num_; + ShapeVector input_shape_; + ShapeVector begin_shape_; + ShapeVector end_shape_; + ShapeVector stride_shape_; ShapeVector output_shape_; - std::vector output_element_num_; TypeId dtype_{kTypeUnknown}; TypeId dtype_grad_attr{kTypeUnknown}; std::string kernel_type_{kUnknown}; diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/stridedslice_v2_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/stridedslice_v2_cpu_kernel.cc index 0e63de0c9be..1d3d7a019a1 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/stridedslice_v2_cpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/stridedslice_v2_cpu_kernel.cc @@ -311,22 +311,22 @@ static std::map> support_list_map = { .AddOutputAttr(kNumberTypeComplex128)}}}; template -void ParseStrideSliceMasksST(const CNodePtr &kernel_node, std::vector *begin, std::vector *end, +void ParseStrideSliceMasksST(const BaseOperatorPtr &base_operator, std::vector *begin, std::vector *end, std::vector *stride, const ShapeVector &input_shape, size_t shape_dim_input, size_t slice_len) { std::vector &_begin_attr = *begin; std::vector &_end_attr = *end; std::vector &_stride_attr = *stride; - - auto begin_mask_int = common::AnfAlgo::GetNodeAttr(kernel_node, kAttrBeginMask); + auto prim = base_operator->GetPrim(); + auto begin_mask_int = GetValue(prim->GetAttr(kAttrBeginMask)); auto begin_mask = Dec2Bin(begin_mask_int); - auto end_mask_int = common::AnfAlgo::GetNodeAttr(kernel_node, kAttrEndMask); + auto end_mask_int = GetValue(prim->GetAttr(kAttrEndMask)); auto end_mask = Dec2Bin(end_mask_int); - auto ellipsis_mask_int = common::AnfAlgo::GetNodeAttr(kernel_node, kAttrEllipsisMask); + auto ellipsis_mask_int = GetValue(prim->GetAttr(kAttrEllipsisMask)); auto ellipsis_mask = Dec2Bin(ellipsis_mask_int); - auto new_axis_mask_int = common::AnfAlgo::GetNodeAttr(kernel_node, kAttrNewAxisMask); + auto new_axis_mask_int = GetValue(prim->GetAttr(kAttrNewAxisMask)); auto new_axis_mask = Dec2Bin(new_axis_mask_int); - auto shrink_axis_mask_int = common::AnfAlgo::GetNodeAttr(kernel_node, kAttrShrinkAxisMask); + auto shrink_axis_mask_int = GetValue(prim->GetAttr(kAttrShrinkAxisMask)); auto shrink_axis_mask = Dec2Bin(shrink_axis_mask_int); size_t i = 0; size_t j = 0; @@ -380,14 +380,14 @@ void ParseStrideSliceMasksST(const CNodePtr &kernel_node, std::vector *begin, } template -void FillEmptyDimsST(const CNodePtr &kernel_node, std::vector *begin, std::vector *end, std::vector *stride, - ShapeVector *input_shape) { +void FillEmptyDimsST(const BaseOperatorPtr &base_operator, std::vector *begin, std::vector *end, + std::vector *stride, ShapeVector *input_shape) { std::vector &_begin = *begin; std::vector &_end = *end; std::vector &_stride = *stride; auto &_input_shape = *input_shape; if (_begin.size() != _end.size() || _begin.size() != _stride.size() || _begin.size() > _input_shape.size()) { - MS_LOG(EXCEPTION) << "For '" << common::AnfAlgo::GetCNodeName(kernel_node) + MS_LOG(EXCEPTION) << "For '" << base_operator->name() << "', the length of 'begin', 'stride' and 'end' should be equal " "and less than or equal to the dimension of 'input_x', but got the length of 'begin': " << _begin.size() << ", the length of 'stride': " << _stride.size() @@ -416,27 +416,59 @@ void FillEmptyDimsST(const CNodePtr &kernel_node, std::vector *begin, std::ve } } // namespace -void StridedSliceV2CpuKernelMod::InitKernel(const CNodePtr &kernel_node) { - MS_EXCEPTION_IF_NULL(kernel_node); - kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node); - cnode_ptr_ = kernel_node; - input_shape_ = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - output_shape_ = common::AnfAlgo::GetOutputInferShape(kernel_node, 0); - dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0); - size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node); - if (input_num != 1) { - return; +bool StridedSliceV2CpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector &inputs, + const std::vector &outputs) { + MS_EXCEPTION_IF_NULL(base_operator); + base_operator_ = base_operator; + kernel_name_ = base_operator->name(); + + if (inputs.size() != kStridedSliceV2InputsNum && inputs.size() != kStridedSliceV2DynamicInputsNum) { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the number of inputs should be " << kStridedSliceV2InputsNum + << " or " << kStridedSliceV2DynamicInputsNum << ", but got " << inputs.size(); } - // for begin, end, stride are const input - auto begin = common::AnfAlgo::GetNodeAttr>(kernel_node, kAttrBegin); - auto end = common::AnfAlgo::GetNodeAttr>(kernel_node, kAttrEnd); - auto stride = common::AnfAlgo::GetNodeAttr>(kernel_node, kAttrStrides); - InitSliceParam(kernel_node, &begin, &end, &stride); + CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kStridedSliceV2OutputsNum, kernel_name_); + + auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); + auto is_match = MatchKernelAttr(kernel_attr, GetOpSupport()); + if (!is_match.first) { + MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel data type: " << kernel_attr; + return false; + } + + dtype_ = inputs[kIndex0]->GetDtype(); + dtype_attr_ = inputs[kIndex1]->GetDtype(); + + return true; +} + +int StridedSliceV2CpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector &inputs, + const std::vector &outputs, + const std::map &) { + if (auto ret = KernelMod::Resize(base_operator, inputs, outputs); ret != KRET_OK) { + return ret; + } + + input_shape_ = inputs[kIndex0]->GetShapeVector(); + begin_shape_ = inputs[kIndex1]->GetShapeVector(); + end_shape_ = inputs[kIndex2]->GetShapeVector(); + stride_shape_ = inputs[kIndex3]->GetShapeVector(); + output_shape_ = outputs[kIndex0]->GetShapeVector(); + + if (inputs.size() == kStridedSliceV2DynamicInputsNum) { + return KRET_OK; + } + + auto prim = base_operator->GetPrim(); + auto begin = GetValue>(prim->GetAttr(kAttrBegin)); + auto end = GetValue>(prim->GetAttr(kAttrEnd)); + auto stride = GetValue>(prim->GetAttr(kAttrStrides)); + InitSliceParam(base_operator, &begin, &end, &stride); parallel_ = MatchParallelPattern(); if (parallel_) { InitParallelParam(); } + return KRET_OK; } bool StridedSliceV2CpuKernelMod::MatchParallelPattern() { @@ -492,8 +524,8 @@ void StridedSliceV2CpuKernelMod::InitParallelParam() { } template -void StridedSliceV2CpuKernelMod::InitSliceParam(const CNodePtr &kernel_node, std::vector *begin, std::vector *end, - std::vector *stride) { +void StridedSliceV2CpuKernelMod::InitSliceParam(const BaseOperatorPtr &base_operator, std::vector *begin, + std::vector *end, std::vector *stride) { static const std::unordered_map> type_convert_map = { {kNumberTypeBool, {::kNumberTypeBool, sizeof(bool)}}, {kNumberTypeInt8, {::kNumberTypeInt8, sizeof(int8_t)}}, @@ -521,16 +553,16 @@ void StridedSliceV2CpuKernelMod::InitSliceParam(const CNodePtr &kernel_node, std slice_param_.data_type = type_pair->second.first; auto input_shape_pad = input_shape_; shape_dim_input = input_shape_.size(); - FillEmptyDimsST(kernel_node, begin, end, stride, &input_shape_pad); - ParseStrideSliceMasksST(kernel_node, begin, end, stride, input_shape_, shape_dim_input, slice_len); - FillEmptyDimsST(kernel_node, begin, end, stride, &input_shape_pad); + FillEmptyDimsST(base_operator, begin, end, stride, &input_shape_pad); + ParseStrideSliceMasksST(base_operator, begin, end, stride, input_shape_, shape_dim_input, slice_len); + FillEmptyDimsST(base_operator, begin, end, stride, &input_shape_pad); std::vector &_begin = *begin; std::vector &_end = *end; std::vector &_stride = *stride; for (size_t i = 0; i < DIMENSION_8D; i++) { slice_param_.in_shape_[i] = SizeToInt(input_shape_pad[i]); - if (dtype_attr == kNumberTypeInt64) { + if (dtype_attr_ == kNumberTypeInt64) { slice_param_.begins_[i] = LongToInt(_begin[i]); slice_param_.ends_[i] = LongToInt(_end[i]); slice_param_.strides_[i] = LongToInt(_stride[i]); @@ -596,66 +628,43 @@ void StridedSliceV2CpuKernelMod::ParallelRun(const uint8_t *input_addr, uint8_t } template -bool StridedSliceV2CpuKernelMod::StridedSliceV2LaunchDynamicType(const std::vector &inputs) { - auto cnode = cnode_ptr_.lock(); - auto begin_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(cnode, 1); - auto end_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(cnode, 2); - auto stride_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(cnode, 3); - if (begin_shape.size() != 1 || end_shape.size() != 1 || stride_shape.size() != 1) { +void StridedSliceV2CpuKernelMod::StridedSliceV2LaunchDynamicType(const std::vector &inputs) { + if (begin_shape_.size() != 1 || end_shape_.size() != 1 || stride_shape_.size() != 1) { MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dimension of 'begin', 'end', 'strides' should be equal " "to 1, but got the dimension of 'begin': " - << begin_shape.size() << ", the dimension of 'end': " << end_shape.size() - << ", and the dimension of 'strides': " << stride_shape.size(); + << begin_shape_.size() << ", the dimension of 'end': " << end_shape_.size() + << ", and the dimension of 'strides': " << stride_shape_.size(); } auto begin_ptr = static_cast(inputs[1]->addr); auto end_ptr = static_cast(inputs[2]->addr); auto strides_ptr = static_cast(inputs[3]->addr); - std::vector begin{begin_ptr, begin_ptr + begin_shape[0]}; - std::vector end{end_ptr, end_ptr + end_shape[0]}; - std::vector stride{strides_ptr, strides_ptr + stride_shape[0]}; + std::vector begin{begin_ptr, begin_ptr + begin_shape_[0]}; + std::vector end{end_ptr, end_ptr + end_shape_[0]}; + std::vector stride{strides_ptr, strides_ptr + stride_shape_[0]}; slice_len = begin.size(); - InitSliceParam(cnode, &begin, &end, &stride); - return true; + InitSliceParam(base_operator_, &begin, &end, &stride); } -bool StridedSliceV2CpuKernelMod::StridedSliceV2LaunchCal(const std::vector &inputs, +void StridedSliceV2CpuKernelMod::StridedSliceV2LaunchCal(const std::vector &inputs, const std::vector &outputs) { - if (inputs.size() != kStridedSliceV2InputsNum && inputs.size() != kStridedSliceV2DynamicInputsNum) { - MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the number of inputs should be " << kStridedSliceV2InputsNum - << " or " << kStridedSliceV2DynamicInputsNum << ", but got " << inputs.size(); + // for begin, end, stride are not const input + if (dtype_attr_ == kNumberTypeInt32) { + StridedSliceV2LaunchDynamicType(inputs); + } else { + StridedSliceV2LaunchDynamicType(inputs); } - CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kStridedSliceV2OutputsNum, kernel_name_); - - auto cnode = cnode_ptr_.lock(); - size_t input_num = common::AnfAlgo::GetInputTensorNum(cnode); - if (input_num == kStridedSliceV2DynamicInputsNum) { - bool flag = false; - // for begin, end, stride are not const input - dtype_attr = AnfAlgo::GetInputDeviceDataType(cnode, 1); - if (dtype_attr == kNumberTypeInt32) { - flag = StridedSliceV2LaunchDynamicType(inputs); - } else { - flag = StridedSliceV2LaunchDynamicType(inputs); - } - if (!flag) { - MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the begin、end、stride is calculation error."; - } - - parallel_ = MatchParallelPattern(); - if (parallel_) { - InitParallelParam(); - } + parallel_ = MatchParallelPattern(); + if (parallel_) { + InitParallelParam(); } - return true; } bool StridedSliceV2CpuKernelMod::Launch(const std::vector &inputs, const std::vector & /* workspace */, const std::vector &outputs) { - bool ret = StridedSliceV2LaunchCal(inputs, outputs); - if (ret != true) { - MS_LOG(EXCEPTION) << "For StridedSliceV2 LaunchCal failed."; + if (inputs.size() == kStridedSliceV2DynamicInputsNum) { + StridedSliceV2LaunchCal(inputs, outputs); } auto input_addr = static_cast(inputs[0]->addr); auto output_addr = static_cast(outputs[0]->addr); diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/stridedslice_v2_cpu_kernel.h b/mindspore/ccsrc/plugin/device/cpu/kernel/stridedslice_v2_cpu_kernel.h index d985202a2b0..74a1f932647 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/stridedslice_v2_cpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/stridedslice_v2_cpu_kernel.h @@ -20,6 +20,7 @@ #include #include #include +#include #include "plugin/device/cpu/kernel/cpu_kernel.h" #include "plugin/factory/ms_factory.h" #include "nnacl/fp32/strided_slice_fp32.h" @@ -27,13 +28,15 @@ namespace mindspore { namespace kernel { constexpr auto kStridedSliceV2 = "StridedSliceV2"; -class StridedSliceV2CpuKernelMod : public DeprecatedNativeCpuKernelMod { +class StridedSliceV2CpuKernelMod : public NativeCpuKernelMod { public: StridedSliceV2CpuKernelMod() = default; ~StridedSliceV2CpuKernelMod() override = default; - void InitKernel(const CNodePtr &kernel_node) override; - + bool Init(const BaseOperatorPtr &base_operator, const std::vector &inputs, + const std::vector &outputs) override; + int Resize(const BaseOperatorPtr &base_operator, const std::vector &inputs, + const std::vector &outputs, const std::map &) override; bool Launch(const std::vector &inputs, const std::vector &workspace, const std::vector &outputs) override; @@ -44,33 +47,38 @@ class StridedSliceV2CpuKernelMod : public DeprecatedNativeCpuKernelMod { enum ParallelStrategy { kOnSplitAxis, kOnOuter }; template - bool StridedSliceV2LaunchDynamicType(const std::vector &inputs); + void StridedSliceV2LaunchDynamicType(const std::vector &inputs); template - void InitSliceParam(const CNodePtr &kernel_node, std::vector *begin, std::vector *end, std::vector *stride); + void InitSliceParam(const BaseOperatorPtr &base_operator, std::vector *begin, std::vector *end, + std::vector *stride); bool MatchParallelPattern(); void InitParallelParam(); void ParallelRun(const uint8_t *input_addr, uint8_t *output_addr, int thread_num); - bool StridedSliceV2LaunchCal(const std::vector &inputs, + void StridedSliceV2LaunchCal(const std::vector &inputs, const std::vector &outputs); common::Status RunTaskOnOuter(const uint8_t *input_addr, uint8_t *output_addr, int start_pos); common::Status RunTaskOnSplitAxis(const uint8_t *input_addr, uint8_t *output_addr, int start_pos); - void ParseMasks(const CNodePtr &kernel_node); TypeId dtype_; - TypeId dtype_attr; + TypeId dtype_attr_; int data_size_{4}; int split_axis_{-1}; int inner_{1}; int outer_{1}; int cal_num_per_thread_{1}; bool parallel_{false}; + BaseOperatorPtr base_operator_; + size_t inputs_num_; size_t shape_dim_input; size_t slice_len; ParallelStrategy parallel_strategy_{kOnSplitAxis}; ShapeVector input_shape_; ShapeVector output_shape_; + ShapeVector begin_shape_; + ShapeVector end_shape_; + ShapeVector stride_shape_; StridedSliceParameter slice_param_; }; } // namespace kernel