forked from mindspore-Ecosystem/mindspore
!44693 回退代码:输入不转属性算子stridedslice_gpu
Merge pull request !44693 from Yanzhi_YI/fix_bug_strided_slice
This commit is contained in:
commit
dd512e4d7e
|
@ -34,7 +34,13 @@ class StridedSliceGpuCommon {
|
|||
StridedSliceGpuCommon() : null_output_(false) {}
|
||||
~StridedSliceGpuCommon() = default;
|
||||
|
||||
void CollectInfo(const BaseOperatorPtr &base_operator) {
|
||||
void CollectInfo(const BaseOperatorPtr &base_operator, bool is_dynamic_attr_ = false) {
|
||||
if (!is_dynamic_attr_) {
|
||||
auto kernel_ptr = std::dynamic_pointer_cast<ops::StridedSlice>(base_operator);
|
||||
begin_ = kernel_ptr->get_begin();
|
||||
end_ = kernel_ptr->get_end();
|
||||
strides_ = kernel_ptr->get_strides();
|
||||
}
|
||||
auto shape_tmp = Convert2Long(input_shape_);
|
||||
FillEmptyDims(base_operator, &begin_, &end_, &strides_, &shape_tmp);
|
||||
input_shape_ = Convert2SizeT(shape_tmp);
|
||||
|
|
|
@ -22,6 +22,7 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
constexpr size_t DynamicInputNum = 4;
|
||||
template <typename T>
|
||||
using Complex = mindspore::utils::Complex<T>;
|
||||
|
||||
|
@ -56,6 +57,10 @@ int StridedSliceGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const
|
|||
return ret;
|
||||
}
|
||||
|
||||
if (inputs.size() == DynamicInputNum) {
|
||||
is_dynamic_attr_ = true;
|
||||
}
|
||||
|
||||
auto shape_signed = inputs[0]->GetShapeVector();
|
||||
input_shape_ = Convert2SizeTClipNeg(shape_signed);
|
||||
null_output_ = CHECK_SHAPE_NULL(input_shape_, kernel_name_, "input");
|
||||
|
@ -66,12 +71,12 @@ int StridedSliceGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const
|
|||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dimension of input cannot be greater than " << MAX_DIMS
|
||||
<< ", but got " << input_shape_.size();
|
||||
}
|
||||
|
||||
GetDynamicAttrIntValue(inputs, kBeginIndex_, inputsOnHost, kernel_name_, &begin_);
|
||||
GetDynamicAttrIntValue(inputs, kEndIndex_, inputsOnHost, kernel_name_, &end_);
|
||||
GetDynamicAttrIntValue(inputs, kStrideIndex_, inputsOnHost, kernel_name_, &strides_);
|
||||
|
||||
CollectInfo(base_operator);
|
||||
if (is_dynamic_attr_) {
|
||||
GetDynamicAttrIntValue(inputs, kBeginIndex_, inputsOnHost, kernel_name_, &begin_);
|
||||
GetDynamicAttrIntValue(inputs, kEndIndex_, inputsOnHost, kernel_name_, &end_);
|
||||
GetDynamicAttrIntValue(inputs, kStrideIndex_, inputsOnHost, kernel_name_, &strides_);
|
||||
}
|
||||
CollectInfo(base_operator, is_dynamic_attr_);
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
@ -89,6 +94,21 @@ int StridedSliceGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const
|
|||
&StridedSliceGpuKernelMod::LaunchKernel<TYPE_1, TYPE_2>
|
||||
|
||||
std::vector<std::pair<KernelAttr, StridedSliceGpuKernelMod::StridedSliceFunc>> StridedSliceGpuKernelMod::func_list_ = {
|
||||
{STRIDEDSLICE_GPU_REG(kNumberTypeFloat64, double)},
|
||||
{STRIDEDSLICE_GPU_REG(kNumberTypeFloat32, float)},
|
||||
{STRIDEDSLICE_GPU_REG(kNumberTypeFloat16, half)},
|
||||
{STRIDEDSLICE_GPU_REG(kNumberTypeInt64, int64_t)},
|
||||
{STRIDEDSLICE_GPU_REG(kNumberTypeInt32, int32_t)},
|
||||
{STRIDEDSLICE_GPU_REG(kNumberTypeInt16, int16_t)},
|
||||
{STRIDEDSLICE_GPU_REG(kNumberTypeInt8, int8_t)},
|
||||
{STRIDEDSLICE_GPU_REG(kNumberTypeUInt64, uint64_t)},
|
||||
{STRIDEDSLICE_GPU_REG(kNumberTypeUInt32, uint32_t)},
|
||||
{STRIDEDSLICE_GPU_REG(kNumberTypeUInt16, uint16_t)},
|
||||
{STRIDEDSLICE_GPU_REG(kNumberTypeUInt8, uint8_t)},
|
||||
{STRIDEDSLICE_GPU_REG(kNumberTypeBool, bool)},
|
||||
{STRIDEDSLICE_GPU_REG(kNumberTypeComplex64, Complex<float>)},
|
||||
{STRIDEDSLICE_GPU_REG(kNumberTypeComplex128, Complex<double>)},
|
||||
|
||||
{STRIDEDSLICE_DYNAMIC_GPU_REG(kNumberTypeFloat64, kNumberTypeInt64, double, int64_t)},
|
||||
{STRIDEDSLICE_DYNAMIC_GPU_REG(kNumberTypeFloat32, kNumberTypeInt64, float, int64_t)},
|
||||
{STRIDEDSLICE_DYNAMIC_GPU_REG(kNumberTypeFloat16, kNumberTypeInt64, half, int64_t)},
|
||||
|
|
|
@ -59,6 +59,8 @@ class StridedSliceGpuKernelMod : public NativeGpuKernelMod, public StridedSliceG
|
|||
StridedSliceFunc kernel_func_;
|
||||
|
||||
bool is_null_input_{false};
|
||||
bool is_dynamic_attr_{false};
|
||||
bool get_dynamic_attr_value_{false};
|
||||
static constexpr size_t kBeginIndex_{1};
|
||||
static constexpr size_t kEndIndex_{2};
|
||||
static constexpr size_t kStrideIndex_{3};
|
||||
|
|
|
@ -77,6 +77,7 @@ RER_GPU_STATIC_CONST_TO_ATTR(kSpaceToBatchOpName, 1);
|
|||
RER_GPU_STATIC_CONST_TO_ATTR(kSparseApplyAdagradOpName, 2);
|
||||
RER_GPU_STATIC_CONST_TO_ATTR(kSparseGatherV2OpName, 2);
|
||||
RER_GPU_STATIC_CONST_TO_ATTR(kStridedSliceAssignOpName, 1, 2, 3);
|
||||
RER_GPU_STATIC_CONST_TO_ATTR(kStridedSliceOpName, 1, 2, 3);
|
||||
RER_GPU_STATIC_CONST_TO_ATTR(kSubscalarOpName, 1);
|
||||
RER_GPU_STATIC_CONST_TO_ATTR(kTensorCopySlicesOpName, 2, 3, 4);
|
||||
RER_GPU_STATIC_CONST_TO_ATTR(kTileOpName, 1);
|
||||
|
|
Loading…
Reference in New Issue