!39839 [MS][OPS] fix unsortedsegmentsum ops numsegment dyn support int32 issue

Merge pull request !39839 from luoyuan/fix-unsortedsegmentsum-numsegments-dyn
This commit is contained in:
i-robot 2022-08-08 07:15:20 +00:00 committed by Gitee
commit 1da193fc38
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
3 changed files with 37 additions and 4 deletions

View File

@ -227,7 +227,27 @@ const std::vector<std::pair<KernelAttr, KernelRunFunc>> &UnsortedSegmentArithmet
{UNSORTED_SEGMENT_ARITH_CPU_DY_REGISTER(kNumberTypeUInt32, kNumberTypeInt32, kNumberTypeInt32, uint32_t, int)},
{UNSORTED_SEGMENT_ARITH_CPU_DY_REGISTER(kNumberTypeUInt32, kNumberTypeInt64, kNumberTypeInt32, uint32_t, int64_t)},
{UNSORTED_SEGMENT_ARITH_CPU_DY_REGISTER(kNumberTypeUInt64, kNumberTypeInt32, kNumberTypeInt32, uint64_t, int)},
{UNSORTED_SEGMENT_ARITH_CPU_DY_REGISTER(kNumberTypeUInt64, kNumberTypeInt64, kNumberTypeInt32, uint64_t, int64_t)}};
{UNSORTED_SEGMENT_ARITH_CPU_DY_REGISTER(kNumberTypeUInt64, kNumberTypeInt64, kNumberTypeInt32, uint64_t, int64_t)},
{UNSORTED_SEGMENT_ARITH_CPU_DY_REGISTER(kNumberTypeFloat64, kNumberTypeInt32, kNumberTypeInt64, double, int)},
{UNSORTED_SEGMENT_ARITH_CPU_DY_REGISTER(kNumberTypeFloat64, kNumberTypeInt64, kNumberTypeInt64, double, int64_t)},
{UNSORTED_SEGMENT_ARITH_CPU_DY_REGISTER(kNumberTypeFloat32, kNumberTypeInt32, kNumberTypeInt64, float, int)},
{UNSORTED_SEGMENT_ARITH_CPU_DY_REGISTER(kNumberTypeFloat32, kNumberTypeInt64, kNumberTypeInt64, float, int64_t)},
{UNSORTED_SEGMENT_ARITH_CPU_DY_REGISTER(kNumberTypeInt32, kNumberTypeInt32, kNumberTypeInt64, float, int)},
{UNSORTED_SEGMENT_ARITH_CPU_DY_REGISTER(kNumberTypeInt32, kNumberTypeInt64, kNumberTypeInt64, float, int64_t)},
{UNSORTED_SEGMENT_ARITH_CPU_DY_REGISTER(kNumberTypeUInt8, kNumberTypeInt32, kNumberTypeInt64, uint8_t, int)},
{UNSORTED_SEGMENT_ARITH_CPU_DY_REGISTER(kNumberTypeUInt8, kNumberTypeInt64, kNumberTypeInt64, uint8_t, int64_t)},
{UNSORTED_SEGMENT_ARITH_CPU_DY_REGISTER(kNumberTypeInt16, kNumberTypeInt32, kNumberTypeInt64, int16_t, int)},
{UNSORTED_SEGMENT_ARITH_CPU_DY_REGISTER(kNumberTypeInt16, kNumberTypeInt64, kNumberTypeInt64, int16_t, int64_t)},
{UNSORTED_SEGMENT_ARITH_CPU_DY_REGISTER(kNumberTypeInt8, kNumberTypeInt32, kNumberTypeInt64, int8_t, int)},
{UNSORTED_SEGMENT_ARITH_CPU_DY_REGISTER(kNumberTypeInt8, kNumberTypeInt64, kNumberTypeInt64, int8_t, int64_t)},
{UNSORTED_SEGMENT_ARITH_CPU_DY_REGISTER(kNumberTypeInt64, kNumberTypeInt32, kNumberTypeInt64, int64_t, int)},
{UNSORTED_SEGMENT_ARITH_CPU_DY_REGISTER(kNumberTypeInt64, kNumberTypeInt64, kNumberTypeInt64, int64_t, int64_t)},
{UNSORTED_SEGMENT_ARITH_CPU_DY_REGISTER(kNumberTypeUInt16, kNumberTypeInt32, kNumberTypeInt64, uint16_t, int)},
{UNSORTED_SEGMENT_ARITH_CPU_DY_REGISTER(kNumberTypeUInt16, kNumberTypeInt64, kNumberTypeInt64, uint16_t, int64_t)},
{UNSORTED_SEGMENT_ARITH_CPU_DY_REGISTER(kNumberTypeUInt32, kNumberTypeInt32, kNumberTypeInt64, uint32_t, int)},
{UNSORTED_SEGMENT_ARITH_CPU_DY_REGISTER(kNumberTypeUInt32, kNumberTypeInt64, kNumberTypeInt64, uint32_t, int64_t)},
{UNSORTED_SEGMENT_ARITH_CPU_DY_REGISTER(kNumberTypeUInt64, kNumberTypeInt32, kNumberTypeInt64, uint64_t, int)},
{UNSORTED_SEGMENT_ARITH_CPU_DY_REGISTER(kNumberTypeUInt64, kNumberTypeInt64, kNumberTypeInt64, uint64_t, int64_t)}};
return func_list;
}
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, UnsortedSegmentMin, UnsortedSegmentArithmeticCpuKernelMod);

View File

@ -41,7 +41,8 @@ void GetNumSegmentsValue(const PrimitivePtr &primitive, const std::vector<Abstra
if (n_value_ptr->isa<tensor::Tensor>()) {
auto n_tensor_ptr = n_value_ptr->cast<tensor::TensorPtr>();
MS_EXCEPTION_IF_NULL(n_tensor_ptr);
num_segments_v = *static_cast<int64_t *>(n_tensor_ptr->data_c());
num_segments_v = n_tensor_ptr->data_type() == kNumberTypeInt32 ? *static_cast<int32_t *>(n_tensor_ptr->data_c())
: *static_cast<int64_t *>(n_tensor_ptr->data_c());
(void)CheckAndConvertUtils::CheckInteger("num_segments's value", num_segments_v, kGreaterThan, 0, op_name);
num_vec->push_back(num_segments_v);
num_min_vec->push_back(num_segments_v);
@ -58,7 +59,19 @@ void GetNumSegmentsValue(const PrimitivePtr &primitive, const std::vector<Abstra
}
}
} else if (input_args[kInputIndex2]->isa<abstract::AbstractScalar>()) {
num_segments_v = GetValue<int64_t>(input_args[kInputIndex2]->BuildValue());
auto num_segments_input_type = input_args[kInputIndex2]->BuildType();
if (num_segments_input_type->type_id() == kNumberTypeInt64) {
auto num_sample_ptr = input_args[kInputIndex2]->cast<abstract::AbstractScalarPtr>();
MS_EXCEPTION_IF_NULL(num_sample_ptr);
num_segments_v = GetValue<int64_t>(input_args[kInputIndex2]->BuildValue());
} else if (num_segments_input_type->type_id() == kNumberTypeInt32) {
auto num_sample_ptr = input_args[kInputIndex2]->cast<abstract::AbstractScalarPtr>();
MS_EXCEPTION_IF_NULL(num_sample_ptr);
num_segments_v = GetValue<int32_t>(input_args[kInputIndex2]->BuildValue());
} else {
MS_EXCEPTION(TypeError) << "For '" << op_name << "' the third input build type is invalid:"
<< TypeIdToString(num_segments_input_type->type_id()) << ".";
}
(void)CheckAndConvertUtils::CheckInteger("num_segments's value", num_segments_v, kGreaterThan, 0, op_name);
num_vec->push_back(num_segments_v);
num_min_vec->push_back(num_segments_v);

View File

@ -235,7 +235,7 @@ def test_dynamic_getitem_tensor():
fact.grad_impl()
@pytest.mark.level1
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training