forked from mindspore-Ecosystem/mindspore
!39839 [MS][OPS] fix unsortedsegmentsum ops numsegment dyn support int32 issue
Merge pull request !39839 from luoyuan/fix-unsortedsegmentsum-numsegments-dyn
This commit is contained in:
commit
1da193fc38
|
@ -227,7 +227,27 @@ const std::vector<std::pair<KernelAttr, KernelRunFunc>> &UnsortedSegmentArithmet
|
|||
{UNSORTED_SEGMENT_ARITH_CPU_DY_REGISTER(kNumberTypeUInt32, kNumberTypeInt32, kNumberTypeInt32, uint32_t, int)},
|
||||
{UNSORTED_SEGMENT_ARITH_CPU_DY_REGISTER(kNumberTypeUInt32, kNumberTypeInt64, kNumberTypeInt32, uint32_t, int64_t)},
|
||||
{UNSORTED_SEGMENT_ARITH_CPU_DY_REGISTER(kNumberTypeUInt64, kNumberTypeInt32, kNumberTypeInt32, uint64_t, int)},
|
||||
{UNSORTED_SEGMENT_ARITH_CPU_DY_REGISTER(kNumberTypeUInt64, kNumberTypeInt64, kNumberTypeInt32, uint64_t, int64_t)}};
|
||||
{UNSORTED_SEGMENT_ARITH_CPU_DY_REGISTER(kNumberTypeUInt64, kNumberTypeInt64, kNumberTypeInt32, uint64_t, int64_t)},
|
||||
{UNSORTED_SEGMENT_ARITH_CPU_DY_REGISTER(kNumberTypeFloat64, kNumberTypeInt32, kNumberTypeInt64, double, int)},
|
||||
{UNSORTED_SEGMENT_ARITH_CPU_DY_REGISTER(kNumberTypeFloat64, kNumberTypeInt64, kNumberTypeInt64, double, int64_t)},
|
||||
{UNSORTED_SEGMENT_ARITH_CPU_DY_REGISTER(kNumberTypeFloat32, kNumberTypeInt32, kNumberTypeInt64, float, int)},
|
||||
{UNSORTED_SEGMENT_ARITH_CPU_DY_REGISTER(kNumberTypeFloat32, kNumberTypeInt64, kNumberTypeInt64, float, int64_t)},
|
||||
{UNSORTED_SEGMENT_ARITH_CPU_DY_REGISTER(kNumberTypeInt32, kNumberTypeInt32, kNumberTypeInt64, float, int)},
|
||||
{UNSORTED_SEGMENT_ARITH_CPU_DY_REGISTER(kNumberTypeInt32, kNumberTypeInt64, kNumberTypeInt64, float, int64_t)},
|
||||
{UNSORTED_SEGMENT_ARITH_CPU_DY_REGISTER(kNumberTypeUInt8, kNumberTypeInt32, kNumberTypeInt64, uint8_t, int)},
|
||||
{UNSORTED_SEGMENT_ARITH_CPU_DY_REGISTER(kNumberTypeUInt8, kNumberTypeInt64, kNumberTypeInt64, uint8_t, int64_t)},
|
||||
{UNSORTED_SEGMENT_ARITH_CPU_DY_REGISTER(kNumberTypeInt16, kNumberTypeInt32, kNumberTypeInt64, int16_t, int)},
|
||||
{UNSORTED_SEGMENT_ARITH_CPU_DY_REGISTER(kNumberTypeInt16, kNumberTypeInt64, kNumberTypeInt64, int16_t, int64_t)},
|
||||
{UNSORTED_SEGMENT_ARITH_CPU_DY_REGISTER(kNumberTypeInt8, kNumberTypeInt32, kNumberTypeInt64, int8_t, int)},
|
||||
{UNSORTED_SEGMENT_ARITH_CPU_DY_REGISTER(kNumberTypeInt8, kNumberTypeInt64, kNumberTypeInt64, int8_t, int64_t)},
|
||||
{UNSORTED_SEGMENT_ARITH_CPU_DY_REGISTER(kNumberTypeInt64, kNumberTypeInt32, kNumberTypeInt64, int64_t, int)},
|
||||
{UNSORTED_SEGMENT_ARITH_CPU_DY_REGISTER(kNumberTypeInt64, kNumberTypeInt64, kNumberTypeInt64, int64_t, int64_t)},
|
||||
{UNSORTED_SEGMENT_ARITH_CPU_DY_REGISTER(kNumberTypeUInt16, kNumberTypeInt32, kNumberTypeInt64, uint16_t, int)},
|
||||
{UNSORTED_SEGMENT_ARITH_CPU_DY_REGISTER(kNumberTypeUInt16, kNumberTypeInt64, kNumberTypeInt64, uint16_t, int64_t)},
|
||||
{UNSORTED_SEGMENT_ARITH_CPU_DY_REGISTER(kNumberTypeUInt32, kNumberTypeInt32, kNumberTypeInt64, uint32_t, int)},
|
||||
{UNSORTED_SEGMENT_ARITH_CPU_DY_REGISTER(kNumberTypeUInt32, kNumberTypeInt64, kNumberTypeInt64, uint32_t, int64_t)},
|
||||
{UNSORTED_SEGMENT_ARITH_CPU_DY_REGISTER(kNumberTypeUInt64, kNumberTypeInt32, kNumberTypeInt64, uint64_t, int)},
|
||||
{UNSORTED_SEGMENT_ARITH_CPU_DY_REGISTER(kNumberTypeUInt64, kNumberTypeInt64, kNumberTypeInt64, uint64_t, int64_t)}};
|
||||
return func_list;
|
||||
}
|
||||
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, UnsortedSegmentMin, UnsortedSegmentArithmeticCpuKernelMod);
|
||||
|
|
|
@ -41,7 +41,8 @@ void GetNumSegmentsValue(const PrimitivePtr &primitive, const std::vector<Abstra
|
|||
if (n_value_ptr->isa<tensor::Tensor>()) {
|
||||
auto n_tensor_ptr = n_value_ptr->cast<tensor::TensorPtr>();
|
||||
MS_EXCEPTION_IF_NULL(n_tensor_ptr);
|
||||
num_segments_v = *static_cast<int64_t *>(n_tensor_ptr->data_c());
|
||||
num_segments_v = n_tensor_ptr->data_type() == kNumberTypeInt32 ? *static_cast<int32_t *>(n_tensor_ptr->data_c())
|
||||
: *static_cast<int64_t *>(n_tensor_ptr->data_c());
|
||||
(void)CheckAndConvertUtils::CheckInteger("num_segments's value", num_segments_v, kGreaterThan, 0, op_name);
|
||||
num_vec->push_back(num_segments_v);
|
||||
num_min_vec->push_back(num_segments_v);
|
||||
|
@ -58,7 +59,19 @@ void GetNumSegmentsValue(const PrimitivePtr &primitive, const std::vector<Abstra
|
|||
}
|
||||
}
|
||||
} else if (input_args[kInputIndex2]->isa<abstract::AbstractScalar>()) {
|
||||
num_segments_v = GetValue<int64_t>(input_args[kInputIndex2]->BuildValue());
|
||||
auto num_segments_input_type = input_args[kInputIndex2]->BuildType();
|
||||
if (num_segments_input_type->type_id() == kNumberTypeInt64) {
|
||||
auto num_sample_ptr = input_args[kInputIndex2]->cast<abstract::AbstractScalarPtr>();
|
||||
MS_EXCEPTION_IF_NULL(num_sample_ptr);
|
||||
num_segments_v = GetValue<int64_t>(input_args[kInputIndex2]->BuildValue());
|
||||
} else if (num_segments_input_type->type_id() == kNumberTypeInt32) {
|
||||
auto num_sample_ptr = input_args[kInputIndex2]->cast<abstract::AbstractScalarPtr>();
|
||||
MS_EXCEPTION_IF_NULL(num_sample_ptr);
|
||||
num_segments_v = GetValue<int32_t>(input_args[kInputIndex2]->BuildValue());
|
||||
} else {
|
||||
MS_EXCEPTION(TypeError) << "For '" << op_name << "' the third input build type is invalid:"
|
||||
<< TypeIdToString(num_segments_input_type->type_id()) << ".";
|
||||
}
|
||||
(void)CheckAndConvertUtils::CheckInteger("num_segments's value", num_segments_v, kGreaterThan, 0, op_name);
|
||||
num_vec->push_back(num_segments_v);
|
||||
num_min_vec->push_back(num_segments_v);
|
||||
|
|
|
@ -235,7 +235,7 @@ def test_dynamic_getitem_tensor():
|
|||
fact.grad_impl()
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
|
|
Loading…
Reference in New Issue