!44693 回退代码:输入不转属性算子stridedslice_gpu

Merge pull request !44693 from Yanzhi_YI/fix_bug_strided_slice
This commit is contained in:
i-robot 2022-10-28 03:36:29 +00:00 committed by Gitee
commit dd512e4d7e
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
4 changed files with 36 additions and 7 deletions

View File

@ -34,7 +34,13 @@ class StridedSliceGpuCommon {
StridedSliceGpuCommon() : null_output_(false) {}
~StridedSliceGpuCommon() = default;
void CollectInfo(const BaseOperatorPtr &base_operator) {
void CollectInfo(const BaseOperatorPtr &base_operator, bool is_dynamic_attr_ = false) {
if (!is_dynamic_attr_) {
auto kernel_ptr = std::dynamic_pointer_cast<ops::StridedSlice>(base_operator);
begin_ = kernel_ptr->get_begin();
end_ = kernel_ptr->get_end();
strides_ = kernel_ptr->get_strides();
}
auto shape_tmp = Convert2Long(input_shape_);
FillEmptyDims(base_operator, &begin_, &end_, &strides_, &shape_tmp);
input_shape_ = Convert2SizeT(shape_tmp);

View File

@ -22,6 +22,7 @@
namespace mindspore {
namespace kernel {
constexpr size_t DynamicInputNum = 4;
template <typename T>
using Complex = mindspore::utils::Complex<T>;
@ -56,6 +57,10 @@ int StridedSliceGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const
return ret;
}
if (inputs.size() == DynamicInputNum) {
is_dynamic_attr_ = true;
}
auto shape_signed = inputs[0]->GetShapeVector();
input_shape_ = Convert2SizeTClipNeg(shape_signed);
null_output_ = CHECK_SHAPE_NULL(input_shape_, kernel_name_, "input");
@ -66,12 +71,12 @@ int StridedSliceGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dimension of input cannot be greater than " << MAX_DIMS
<< ", but got " << input_shape_.size();
}
GetDynamicAttrIntValue(inputs, kBeginIndex_, inputsOnHost, kernel_name_, &begin_);
GetDynamicAttrIntValue(inputs, kEndIndex_, inputsOnHost, kernel_name_, &end_);
GetDynamicAttrIntValue(inputs, kStrideIndex_, inputsOnHost, kernel_name_, &strides_);
CollectInfo(base_operator);
if (is_dynamic_attr_) {
GetDynamicAttrIntValue(inputs, kBeginIndex_, inputsOnHost, kernel_name_, &begin_);
GetDynamicAttrIntValue(inputs, kEndIndex_, inputsOnHost, kernel_name_, &end_);
GetDynamicAttrIntValue(inputs, kStrideIndex_, inputsOnHost, kernel_name_, &strides_);
}
CollectInfo(base_operator, is_dynamic_attr_);
return ret;
}
@ -89,6 +94,21 @@ int StridedSliceGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const
&StridedSliceGpuKernelMod::LaunchKernel<TYPE_1, TYPE_2>
std::vector<std::pair<KernelAttr, StridedSliceGpuKernelMod::StridedSliceFunc>> StridedSliceGpuKernelMod::func_list_ = {
{STRIDEDSLICE_GPU_REG(kNumberTypeFloat64, double)},
{STRIDEDSLICE_GPU_REG(kNumberTypeFloat32, float)},
{STRIDEDSLICE_GPU_REG(kNumberTypeFloat16, half)},
{STRIDEDSLICE_GPU_REG(kNumberTypeInt64, int64_t)},
{STRIDEDSLICE_GPU_REG(kNumberTypeInt32, int32_t)},
{STRIDEDSLICE_GPU_REG(kNumberTypeInt16, int16_t)},
{STRIDEDSLICE_GPU_REG(kNumberTypeInt8, int8_t)},
{STRIDEDSLICE_GPU_REG(kNumberTypeUInt64, uint64_t)},
{STRIDEDSLICE_GPU_REG(kNumberTypeUInt32, uint32_t)},
{STRIDEDSLICE_GPU_REG(kNumberTypeUInt16, uint16_t)},
{STRIDEDSLICE_GPU_REG(kNumberTypeUInt8, uint8_t)},
{STRIDEDSLICE_GPU_REG(kNumberTypeBool, bool)},
{STRIDEDSLICE_GPU_REG(kNumberTypeComplex64, Complex<float>)},
{STRIDEDSLICE_GPU_REG(kNumberTypeComplex128, Complex<double>)},
{STRIDEDSLICE_DYNAMIC_GPU_REG(kNumberTypeFloat64, kNumberTypeInt64, double, int64_t)},
{STRIDEDSLICE_DYNAMIC_GPU_REG(kNumberTypeFloat32, kNumberTypeInt64, float, int64_t)},
{STRIDEDSLICE_DYNAMIC_GPU_REG(kNumberTypeFloat16, kNumberTypeInt64, half, int64_t)},

View File

@ -59,6 +59,8 @@ class StridedSliceGpuKernelMod : public NativeGpuKernelMod, public StridedSliceG
StridedSliceFunc kernel_func_;
bool is_null_input_{false};
bool is_dynamic_attr_{false};
bool get_dynamic_attr_value_{false};
static constexpr size_t kBeginIndex_{1};
static constexpr size_t kEndIndex_{2};
static constexpr size_t kStrideIndex_{3};

View File

@ -77,6 +77,7 @@ RER_GPU_STATIC_CONST_TO_ATTR(kSpaceToBatchOpName, 1);
RER_GPU_STATIC_CONST_TO_ATTR(kSparseApplyAdagradOpName, 2);
RER_GPU_STATIC_CONST_TO_ATTR(kSparseGatherV2OpName, 2);
RER_GPU_STATIC_CONST_TO_ATTR(kStridedSliceAssignOpName, 1, 2, 3);
RER_GPU_STATIC_CONST_TO_ATTR(kStridedSliceOpName, 1, 2, 3);
RER_GPU_STATIC_CONST_TO_ATTR(kSubscalarOpName, 1);
RER_GPU_STATIC_CONST_TO_ATTR(kTensorCopySlicesOpName, 2, 3, 4);
RER_GPU_STATIC_CONST_TO_ATTR(kTileOpName, 1);