diff --git a/docs/api/api_python/ops/mindspore.ops.ScatterAdd.rst b/docs/api/api_python/ops/mindspore.ops.ScatterAdd.rst index a72d2a1b10f..da501577a2d 100644 --- a/docs/api/api_python/ops/mindspore.ops.ScatterAdd.rst +++ b/docs/api/api_python/ops/mindspore.ops.ScatterAdd.rst @@ -22,7 +22,7 @@ mindspore.ops.ScatterAdd **输入:** - **input_x** (Parameter) - ScatterAdd的输入,任意维度的Parameter。 - - **indices** (Tensor) - 指定相加操作的索引,数据类型为mindspore.int32。 + - **indices** (Tensor) - 指定相加操作的索引,数据类型为mindspore.int32或者mindspore.int64。 - **updates** (Tensor) - 指定与 `input_x` 相加操作的Tensor,数据类型与 `input_x` 相同,shape为 `indices.shape + x.shape[1:]` 。 **输出:** @@ -32,6 +32,6 @@ mindspore.ops.ScatterAdd **异常:** - **TypeError** - `use_locking` 不是bool。 - - **TypeError** - `indices` 不是int32。 + - **TypeError** - `indices` 不是int32或者int64。 - **ValueError** - `updates` 的shape不等于 `indices.shape + x.shape[1:]` 。 - **RuntimeError** - 当 `input_x` 和 `updates` 类型不一致,需要进行类型转换时,如果 `updates` 不支持转成参数 `input_x` 需要的数据类型,就会报错。 diff --git a/docs/api/api_python/ops/mindspore.ops.ScatterMax.rst b/docs/api/api_python/ops/mindspore.ops.ScatterMax.rst index c5b6561be3e..d001a7778aa 100644 --- a/docs/api/api_python/ops/mindspore.ops.ScatterMax.rst +++ b/docs/api/api_python/ops/mindspore.ops.ScatterMax.rst @@ -20,7 +20,7 @@ mindspore.ops.ScatterMax **输入:** - **input_x** (Parameter)- ScatterMax的输入,任意维度的Parameter。 - - **indices** (Tensor) - 指定最大值操作的索引,数据类型必须为mindspore.int32。 + - **indices** (Tensor) - 指定最大值操作的索引,数据类型必须为mindspore.int32或者mindspore.int64。 - **updates** (Tensor) - 指定与 `input_x` 取最大值操作的Tensor,数据类型与 `input_x` 相同,shape为 `indices.shape + x.shape[1:]` 。 @@ -31,6 +31,6 @@ mindspore.ops.ScatterMax **异常:** - **TypeError** - `use_locking` 不是bool。 - - **TypeError** - `indices` 不是int32。 + - **TypeError** - `indices` 不是int32或者int64。 - **ValueError** - `updates` 的shape不等于 `indices.shape + x.shape[1:]` 。 - **RuntimeError** - 当 `input_x` 和 `updates` 类型不一致,需要进行类型转换时,如果 `updates` 不支持转成参数 `input_x` 需要的数据类型,就会报错。 \ No newline at end of file diff --git a/docs/api/api_python/ops/mindspore.ops.ScatterMin.rst b/docs/api/api_python/ops/mindspore.ops.ScatterMin.rst index b2e4ffe2114..172f3960829 100644 --- a/docs/api/api_python/ops/mindspore.ops.ScatterMin.rst +++ b/docs/api/api_python/ops/mindspore.ops.ScatterMin.rst @@ -11,7 +11,7 @@ mindspore.ops.ScatterMin \text{input_x}[\text{indices}[i, ..., j], :] = min(\text{input_x}[\text{indices}[i, ..., j], :], \text{updates}[i, ..., j, :]) - 输入的 `input_x` 和 `updates` 遵循隐式类型转换规则,以确保数据类型一致。如果数据类型不同,则低精度数据类型将转换为高精度的数据类型。当参数的数据类型需要转换时,则会抛出RuntimeError异常。 + 输入的 `input_x` 和 `updates` 遵循隐式类型转换规则,以确保数据类型一致。如果数据类型不同,则低精度数据类型将转换为高精度的数据类型。当`updates` 不支持转成 `input_x` 需要的数据类型时,则会抛出RuntimeError异常。 **参数:** @@ -20,7 +20,7 @@ mindspore.ops.ScatterMin **输入:** - **input_x** (Parameter) - ScatterMin的输入,任意维度的Parameter。 - - **indices** (Tensor) - 指定最小值操作的索引,数据类型必须为mindspore.int32。 + - **indices** (Tensor) - 指定最小值操作的索引,数据类型必须为mindspore.int32或者mindspore.int64。 - **updates** (Tensor) - 指定与 `input_x` 取最小值操作的Tensor,数据类型与 `input_x` 相同,shape为 `indices.shape + x.shape[1:]` 。 **输出:** @@ -30,6 +30,6 @@ mindspore.ops.ScatterMin **异常:** - **TypeError** - `use_locking` 不是bool。 - - **TypeError** - `indices` 不是int32。 + - **TypeError** - `indices` 不是int32或者int64。 - **ValueError** - `updates` 的shape不等于 `indices.shape + x.shape[1:]` 。 - **RuntimeError** - 当 `input_x` 和 `updates` 类型不一致,需要进行类型转换时,如果 `updates` 不支持转成参数 `input_x` 需要的数据类型,就会报错。 \ No newline at end of file diff --git a/docs/api/api_python/ops/mindspore.ops.ScatterSub.rst b/docs/api/api_python/ops/mindspore.ops.ScatterSub.rst index f3bcc1f60f1..5572cecdfea 100644 --- a/docs/api/api_python/ops/mindspore.ops.ScatterSub.rst +++ b/docs/api/api_python/ops/mindspore.ops.ScatterSub.rst @@ -19,7 +19,7 @@ **输入:** - **input_x** (Parameter) - ScatterSub的输入,任意维度的Parameter。 - - **indices** (Tensor) - 指定相减操作的索引,其数据类型必须为mindspore.int32。 + - **indices** (Tensor) - 指定相减操作的索引,其数据类型必须为mindspore.int32或者mindspore.int64。 - **updates** (Tensor) - 指定与 `input_x` 相减的Tensor,其数据类型与 `input_x` 的相同,shape为 `indices_shape + x_shape[1:]` 。 **输出:** @@ -29,6 +29,6 @@ **异常:** - **TypeError** - `use_locking` 不是bool。 - - **TypeError** - `indices` 不是int32。 + - **TypeError** - `indices` 不是int32或者int64。 - **ValueError** - `updates` 的shape不是 `indices_shape + x_shape[1:]` 。 - **RuntimeError** - 当 `input_x` 和 `updates` 类型不一致,需要进行类型转换时,如果 `updates` 不支持转成参数 `input_x` 需要的数据类型,就会报错。 \ No newline at end of file diff --git a/docs/api/api_python/ops/mindspore.ops.ScatterUpdate.rst b/docs/api/api_python/ops/mindspore.ops.ScatterUpdate.rst index c59df2c8865..66085dff606 100644 --- a/docs/api/api_python/ops/mindspore.ops.ScatterUpdate.rst +++ b/docs/api/api_python/ops/mindspore.ops.ScatterUpdate.rst @@ -19,7 +19,7 @@ **输入:** - **input_x** (Parameter) - ScatterUpdate的输入,任意维度的Parameter。 - - **indices** (Tensor) - 指定更新操作的索引。数据类型为int32。如果索引中存在重复项,则更新的顺序无法得知。 + - **indices** (Tensor) - 指定更新操作的索引。数据类型为int32或者int64。如果索引中存在重复项,则更新的顺序无法得知。 - **updates** (Tensor) - 指定与 `input_x` 更新操作的Tensor,其数据类型与 `input_x` 相同,shape为 `indices.shape + input_x.shape[1:]` 。 **输出:** @@ -29,5 +29,5 @@ **异常:** - **TypeError** - `use_locking` 不是bool。 - - **TypeError** - `indices` 不是int32。 + - **TypeError** - `indices` 不是int32或者int64。 - **RuntimeError** - 当 `input_x` 和 `updates` 类型不一致,需要进行类型转换时,如果 `updates` 不支持转成参数 `input_x` 需要的数据类型,就会报错。 \ No newline at end of file diff --git a/docs/api/api_python/ops/mindspore.ops.func_scatter_max.rst b/docs/api/api_python/ops/mindspore.ops.func_scatter_max.rst index 13704d1ef0f..aac7c1fe754 100644 --- a/docs/api/api_python/ops/mindspore.ops.func_scatter_max.rst +++ b/docs/api/api_python/ops/mindspore.ops.func_scatter_max.rst @@ -16,7 +16,7 @@ mindspore.ops.scatter_max **参数:** - **input_x** (Parameter)- ScatterMax的输入,任意维度的Parameter。 - - **indices** (Tensor) - 指定最大值操作的索引,数据类型必须为mindspore.int32。 + - **indices** (Tensor) - 指定最大值操作的索引,数据类型必须为mindspore.int32或者mindspore.int64。 - **updates** (Tensor) - 指定与 `input_x` 取最大值操作的Tensor,数据类型与 `input_x` 相同,shape为 `indices.shape + x.shape[1:]` 。 **输出:** @@ -26,6 +26,6 @@ mindspore.ops.scatter_max **异常:** - **TypeError** - `use_locking` 不是bool。 - - **TypeError** - `indices` 不是int32。 + - **TypeError** - `indices` 不是int32或者int64。 - **ValueError** - `updates` 的shape不等于 `indices.shape + x.shape[1:]` 。 - **RuntimeError** - 当 `input_x` 和 `updates` 类型不一致,需要进行类型转换时,如果 `updates` 不支持转成参数 `input_x` 需要的数据类型,就会报错。 \ No newline at end of file diff --git a/docs/api/api_python/ops/mindspore.ops.func_scatter_min.rst b/docs/api/api_python/ops/mindspore.ops.func_scatter_min.rst index 276b8a953d7..e81e4d07855 100644 --- a/docs/api/api_python/ops/mindspore.ops.func_scatter_min.rst +++ b/docs/api/api_python/ops/mindspore.ops.func_scatter_min.rst @@ -16,7 +16,7 @@ mindspore.ops.scatter_min **输入:** - **input_x** (Parameter) - scatter_min的输入,任意维度的Parameter。 - - **indices** (Tensor) - 指定最小值操作的索引,数据类型必须为mindspore.int32。 + - **indices** (Tensor) - 指定最小值操作的索引,数据类型必须为mindspore.int32或者mindspore.int64。 - **updates** (Tensor) - 指定与 `input_x` 取最小值操作的Tensor,数据类型与 `input_x` 相同,shape为 `indices.shape + input_x.shape[1:]` 。 **输出:** @@ -25,6 +25,6 @@ mindspore.ops.scatter_min **异常:** - - **TypeError** - `indices` 不是int32。 + - **TypeError** - `indices` 不是int32或者int64。 - **ValueError** - `updates` 的shape不等于 `indices.shape + input_x.shape[1:]` 。 - **RuntimeError** - 当 `input_x` 和 `updates` 类型不一致,需要进行类型转换时,如果 `updates` 不支持转成参数 `input_x` 需要的数据类型,就会报错。 \ No newline at end of file diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/scatter_functor_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/scatter_functor_gpu_kernel.cc index 0da85534fe9..ec6ff90fadc 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/scatter_functor_gpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/scatter_functor_gpu_kernel.cc @@ -307,6 +307,20 @@ MS_REG_GPU_KERNEL_TWO(ScatterMax, ScatterFunctorKernelMod, uint8_t, int64_t) // ScatterMin +MS_REG_GPU_KERNEL_TWO(ScatterMin, + KernelAttr() + .AddInputAttr(kNumberTypeFloat64) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeFloat64) + .AddOutputAttr(kNumberTypeFloat64), + ScatterFunctorKernelMod, double, int) +MS_REG_GPU_KERNEL_TWO(ScatterMin, + KernelAttr() + .AddInputAttr(kNumberTypeFloat64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeFloat64) + .AddOutputAttr(kNumberTypeFloat64), + ScatterFunctorKernelMod, double, int64_t) MS_REG_GPU_KERNEL_TWO(ScatterMin, KernelAttr() .AddInputAttr(kNumberTypeFloat32) @@ -335,6 +349,20 @@ MS_REG_GPU_KERNEL_TWO(ScatterMin, .AddInputAttr(kNumberTypeFloat16) .AddOutputAttr(kNumberTypeFloat16), ScatterFunctorKernelMod, half, int64_t) +MS_REG_GPU_KERNEL_TWO(ScatterMin, + KernelAttr() + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeInt64), + ScatterFunctorKernelMod, int64_t, int) +MS_REG_GPU_KERNEL_TWO(ScatterMin, + KernelAttr() + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeInt64), + ScatterFunctorKernelMod, int64_t, int64_t) MS_REG_GPU_KERNEL_TWO(ScatterMin, KernelAttr() .AddInputAttr(kNumberTypeInt32) @@ -349,33 +377,5 @@ MS_REG_GPU_KERNEL_TWO(ScatterMin, .AddInputAttr(kNumberTypeInt32) .AddOutputAttr(kNumberTypeInt32), ScatterFunctorKernelMod, int, int64_t) -MS_REG_GPU_KERNEL_TWO(ScatterMin, - KernelAttr() - .AddInputAttr(kNumberTypeInt8) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt8) - .AddOutputAttr(kNumberTypeInt8), - ScatterFunctorKernelMod, int8_t, int) -MS_REG_GPU_KERNEL_TWO(ScatterMin, - KernelAttr() - .AddInputAttr(kNumberTypeInt8) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt8) - .AddOutputAttr(kNumberTypeInt8), - ScatterFunctorKernelMod, int8_t, int64_t) -MS_REG_GPU_KERNEL_TWO(ScatterMin, - KernelAttr() - .AddInputAttr(kNumberTypeUInt8) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeUInt8) - .AddOutputAttr(kNumberTypeUInt8), - ScatterFunctorKernelMod, uint8_t, int) -MS_REG_GPU_KERNEL_TWO(ScatterMin, - KernelAttr() - .AddInputAttr(kNumberTypeUInt8) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeUInt8) - .AddOutputAttr(kNumberTypeUInt8), - ScatterFunctorKernelMod, uint8_t, int64_t) } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/scatter_functor_impl.cu b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/scatter_functor_impl.cu index 723fa6354b1..3be979ccefe 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/scatter_functor_impl.cu +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/scatter_functor_impl.cu @@ -128,6 +128,14 @@ template CUDA_LIB_EXPORT void ScatterFunc(enum ScatterFunctorType const size_t &inner_size, const size_t &indices_size, const int64_t *indices, const half *updates, half *input, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void ScatterFunc(enum ScatterFunctorType func_type, int size_limit, + const size_t &inner_size, const size_t &indices_size, + const int *indices, const double *updates, double *input, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void ScatterFunc(enum ScatterFunctorType func_type, int64_t size_limit, + const size_t &inner_size, const size_t &indices_size, + const int64_t *indices, const double *updates, double *input, + cudaStream_t cuda_stream); template CUDA_LIB_EXPORT void ScatterFunc(enum ScatterFunctorType func_type, int size_limit, const size_t &inner_size, const size_t &indices_size, const int *indices, const int *updates, int *input, @@ -136,6 +144,14 @@ template CUDA_LIB_EXPORT void ScatterFunc(enum ScatterFunctorType const size_t &inner_size, const size_t &indices_size, const int64_t *indices, const int *updates, int *input, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void ScatterFunc(enum ScatterFunctorType func_type, int size_limit, + const size_t &inner_size, const size_t &indices_size, + const int *indices, const int64_t *updates, int64_t *input, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void ScatterFunc(enum ScatterFunctorType func_type, int64_t size_limit, + const size_t &inner_size, const size_t &indices_size, + const int64_t *indices, const int64_t *updates, + int64_t *input, cudaStream_t cuda_stream); template CUDA_LIB_EXPORT void ScatterFunc(enum ScatterFunctorType func_type, int size_limit, const size_t &inner_size, const size_t &indices_size, const int *indices, const unsigned char *updates, diff --git a/mindspore/core/ops/scatter_min.cc b/mindspore/core/ops/scatter_min.cc index 32955ead2c1..4d9a3fd022c 100644 --- a/mindspore/core/ops/scatter_min.cc +++ b/mindspore/core/ops/scatter_min.cc @@ -40,7 +40,7 @@ abstract::ShapePtr ScatterMinInferShape(const PrimitivePtr &primitive, const std } if (indices_shape_ptr->IsDynamic() || updates_shape_ptr->IsDynamic()) { - return input_args[kInputIndex0]->BuildShape()->cast(); + return input_x_shape_ptr->cast(); } std::vector input_x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_x_shape_ptr)[kShape]; @@ -66,12 +66,11 @@ TypePtr ScatterMinInferType(const PrimitivePtr &primitive, const std::vectorBuildType(); auto updates_type_ptr = input_args[kInputIndex2]->BuildType(); auto prim_name = primitive->name(); - std::set type_set = {kInt32}; - (void)CheckAndConvertUtils::CheckTensorTypeValid("indices type", indiecs_type_ptr, type_set, prim_name); - (void)CheckAndConvertUtils::CheckTensorTypeValid("input_x type", input_x_type_ptr, common_valid_types_with_complex, - prim_name); - (void)CheckAndConvertUtils::CheckTensorTypeValid("updates type", updates_type_ptr, common_valid_types_with_complex, - prim_name); + const std::set indices_types = {kInt32, kInt64}; + const std::set valid_types = {kInt32, kInt64, kFloat16, kFloat32, kFloat64}; + (void)CheckAndConvertUtils::CheckTensorTypeValid("indices type", indiecs_type_ptr, indices_types, prim_name); + (void)CheckAndConvertUtils::CheckTensorTypeValid("input_x type", input_x_type_ptr, valid_types, prim_name); + (void)CheckAndConvertUtils::CheckTensorTypeValid("updates type", updates_type_ptr, valid_types, prim_name); std::map type_dict; type_dict.emplace("input_x", input_x_type_ptr); diff --git a/mindspore/python/mindspore/ops/_op_impl/tbe/scatter_min.py b/mindspore/python/mindspore/ops/_op_impl/tbe/scatter_min.py index a4ab87a0d79..b8fbe126ac7 100644 --- a/mindspore/python/mindspore/ops/_op_impl/tbe/scatter_min.py +++ b/mindspore/python/mindspore/ops/_op_impl/tbe/scatter_min.py @@ -23,14 +23,19 @@ scatter_min_op_info = TBERegOp("ScatterMin") \ .compute_cost(10) \ .kernel_name("scatter_min") \ .partial_flag(True) \ + .dynamic_compile_static(True) \ + .dynamic_shape(True) \ .attr("use_locking", "optional", "bool", "all") \ .input(0, "var", False, "required", "all") \ .input(1, "indices", False, "required", "all") \ .input(2, "updates", False, "required", "all") \ .output(0, "var", False, "required", "all") \ .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \ + .dtype_format(DataType.F16_Default, DataType.I64_Default, DataType.F16_Default, DataType.F16_Default) \ .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \ + .dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.F32_Default, DataType.F32_Default) \ .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ + .dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.I32_Default, DataType.I32_Default) \ .get_op_info() diff --git a/mindspore/python/mindspore/ops/function/array_func.py b/mindspore/python/mindspore/ops/function/array_func.py index 0524382b224..328955f3259 100644 --- a/mindspore/python/mindspore/ops/function/array_func.py +++ b/mindspore/python/mindspore/ops/function/array_func.py @@ -769,7 +769,7 @@ def scatter_min(input_x, indices, updates): Using given values to update tensor value through the min operation, along with the input indices. This operation outputs the `input_x` after the update is done, which makes it convenient to use the updated value. - for each `i, ..., j` in `indices.shape`: + for each :math:`i, ..., j` in `indices.shape`: .. math:: @@ -781,21 +781,20 @@ def scatter_min(input_x, indices, updates): the relatively highest priority data type. Args: - - **input_x** (Parameter) - The target tensor, with data type of Parameter. - The shape is :math:`(N,*)` where :math:`*` means,any number of additional dimensions. - - **indices** (Tensor) - The index to do min operation whose data type must be mindspore.int32. - - **updates** (Tensor) - The tensor doing the min operation with `input_x`, - the data type is same as `input_x`, the shape is `indices.shape + x.shape[1:]`. + input_x (Parameter): The target tensor, with data type of Parameter. + The shape is :math:`(N,*)` where :math:`*` means,any number of additional dimensions. + indices (Tensor): The index to do min operation whose data type must be mindspore.int32 or mindspore.int64. + updates (Tensor): The tensor doing the min operation with `input_x`, + the data type is same as `input_x`, the shape is `indices.shape + input_x.shape[1:]`. Outputs: Tensor, the updated `input_x`, has the same shape and type as `input_x`. Raises: - TypeError: If `indices` is not an int32. - TypeError: If `use_locking` is not a bool. + TypeError: If `indices` is not an int32 or an int64. RuntimeError: If the data type of `input_x` and `updates` conversion of Parameter is required when data type conversion of Parameter is not supported. - ValueError: If the shape of `updates` is not equal to `indices.shape + x.shape[1:]`. + ValueError: If the shape of `updates` is not equal to `indices.shape + input_x.shape[1:]`. Supported Platforms: ``Ascend`` ``GPU`` ``CPU`` diff --git a/mindspore/python/mindspore/ops/operations/array_ops.py b/mindspore/python/mindspore/ops/operations/array_ops.py index 056c6d7523b..d05ab5ee20f 100755 --- a/mindspore/python/mindspore/ops/operations/array_ops.py +++ b/mindspore/python/mindspore/ops/operations/array_ops.py @@ -109,7 +109,7 @@ class _ScatterOpDynamic(PrimitiveWithCheck): self._check_scatter_shape(x_shape, indices_shape, updates_shape, self.name) def check_dtype(self, x_dtype, indices_dtype, updates_dtype): - validator.check_tensor_dtype_valid('indices', indices_dtype, [mstype.int32], self.name) + validator.check_tensor_dtype_valid('indices', indices_dtype, [mstype.int32, mstype.int64], self.name) args = {"x": x_dtype, "updates": updates_dtype} validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name) @@ -4160,7 +4160,7 @@ class ScatterUpdate(_ScatterOpDynamic): Inputs: - **input_x** (Parameter) - The target tensor, with data type of Parameter. The shape is :math:`(N,*)` where :math:`*` means,any number of additional dimensions. - - **indices** (Tensor) - The index of input tensor. With int32 data type. + - **indices** (Tensor) - The index of input tensor. With int32 or int64 data type. If there are duplicates in indices, the order for updating is undefined. - **updates** (Tensor) - The tensor to update the input tensor, has the same type as input, and updates.shape = indices.shape + input_x.shape[1:]. @@ -4170,7 +4170,7 @@ class ScatterUpdate(_ScatterOpDynamic): Raises: TypeError: If `use_locking` is not a bool. - TypeError: If `indices` is not an int32. + TypeError: If `indices` is not an int32 or an int64. RuntimeError: If the data type of `input_x` and `updates` conversion of Parameter is required when data type conversion of Parameter is not supported. @@ -4289,7 +4289,8 @@ class ScatterMax(_ScatterOpDynamic): Inputs: - **input_x** (Parameter) - The target tensor, with data type of Parameter. The shape is :math:`(N,*)` where :math:`*` means,any number of additional dimensions. - - **indices** (Tensor) - The index to do max operation whose data type must be mindspore.int32. + - **indices** (Tensor) - The index to do max operation whose data type must be mindspore.int32 or + mindspore.int64. - **updates** (Tensor) - The tensor that performs the maximum operation with `input_x`, the data type is the same as `input_x`, the shape is `indices.shape + x.shape[1:]`. @@ -4298,7 +4299,7 @@ class ScatterMax(_ScatterOpDynamic): Raises: TypeError: If `use_locking` is not a bool. - TypeError: If `indices` is not an int32. + TypeError: If `indices` is not an int32 or an int64. ValueError: If the shape of `updates` is not equal to `indices.shape + x.shape[1:]`. RuntimeError: If the data type of `input_x` and `updates` conversion of Parameter is required when data type conversion of Parameter is not supported. @@ -4326,7 +4327,7 @@ class ScatterMin(_ScatterOpDynamic): Using given values to update tensor value through the min operation, along with the input indices. This operation outputs the `input_x` after the update is done, which makes it convenient to use the updated value. - for each `i, ..., j` in `indices.shape`: + for each :math:`i, ..., j` in `indices.shape`: .. math:: @@ -4343,17 +4344,18 @@ class ScatterMin(_ScatterOpDynamic): Inputs: - **input_x** (Parameter) - The target tensor, with data type of Parameter. The shape is :math:`(N,*)` where :math:`*` means,any number of additional dimensions. - - **indices** (Tensor) - The index to do min operation whose data type must be mindspore.int32. + - **indices** (Tensor) - The index to do min operation whose data type must be mindspore.int32 or + mindspore.int64. - **updates** (Tensor) - The tensor doing the min operation with `input_x`, - the data type is same as `input_x`, the shape is `indices.shape + x.shape[1:]`. + the data type is same as `input_x`, the shape is `indices.shape + input_x.shape[1:]`. Outputs: Tensor, the updated `input_x`, has the same shape and type as `input_x`. Raises: TypeError: If `use_locking` is not a bool. - TypeError: If `indices` is not an int32. - ValueError: If the shape of `updates` is not equal to `indices.shape + x.shape[1:]`. + TypeError: If `indices` is not an int32 or an int64. + ValueError: If the shape of `updates` is not equal to `indices.shape + input_x.shape[1:]`. RuntimeError: If the data type of `input_x` and `updates` conversion of Parameter is required when data type conversion of Parameter is not supported. @@ -4401,7 +4403,8 @@ class ScatterAdd(_ScatterOpDynamic): Inputs: - **input_x** (Parameter) - The target tensor, with data type of Parameter. The shape is :math:`(N,*)` where :math:`*` means,any number of additional dimensions. - - **indices** (Tensor) - The index to do min operation whose data type must be mindspore.int32. + - **indices** (Tensor) - The index to do min operation whose data type must be mindspore.int32 or + mindspore.int64. - **updates** (Tensor) - The tensor doing the min operation with `input_x`, the data type is same as `input_x`, the shape is `indices.shape + x.shape[1:]`. @@ -4410,7 +4413,7 @@ class ScatterAdd(_ScatterOpDynamic): Raises: TypeError: If `use_locking` is not a bool. - TypeError: If `indices` is not an int32. + TypeError: If `indices` is not an int32 or an int64. ValueError: If the shape of `updates` is not equal to `indices.shape + x.shape[1:]`. RuntimeError: If the data type of `input_x` and `updates` conversion of Parameter is required when data type conversion of Parameter is not supported. @@ -4511,7 +4514,8 @@ class ScatterSub(_ScatterOpDynamic): Inputs: - **input_x** (Parameter) - The target tensor, with data type of Parameter. The shape is :math:`(N,*)` where :math:`*` means,any number of additional dimensions. - - **indices** (Tensor) - The index to do min operation whose data type must be mindspore.int32. + - **indices** (Tensor) - The index to do min operation whose data type must be mindspore.int32 or + mindspore.int64. - **updates** (Tensor) - The tensor doing the min operation with `input_x`, the data type is same as `input_x`, the shape is `indices_shape + x_shape[1:]`. @@ -4520,7 +4524,7 @@ class ScatterSub(_ScatterOpDynamic): Raises: TypeError: If `use_locking` is not a bool. - TypeError: If `indices` is not an int32. + TypeError: If `indices` is not an int32 or an int64. ValueError: If the shape of `updates` is not equal to `indices_shape + x_shape[1:]`. RuntimeError: If the data type of `input_x` and `updates` conversion of Parameter is required when data type conversion of Parameter is not supported. diff --git a/tests/st/ops/ascend/test_scatter_min.py b/tests/st/ops/ascend/test_scatter_min.py new file mode 100644 index 00000000000..8b470dae23b --- /dev/null +++ b/tests/st/ops/ascend/test_scatter_min.py @@ -0,0 +1,250 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest +import mindspore +import mindspore.context as context +import mindspore.nn as nn +import mindspore.ops as ops +from mindspore import Tensor, Parameter, ParameterTuple + +# all cases tested against dchip + + +class TestScatterMinNet(nn.Cell): + def __init__(self, inputx): + super(TestScatterMinNet, self).__init__() + + self.scatter_min = ops.ScatterMin() + self.inputx = Parameter(inputx, name="inputx") + + def construct(self, indices, updates): + out = self.scatter_min(self.inputx, indices, updates) + return out + + +def scatter_min_forward(nptype): + inputx = Tensor(np.arange(0, 9).reshape((3, 3)).astype(nptype)) + indices = Tensor(np.array([[[1, 0, 2], [2, 2, 0]], [[1, 0, 1], [2, 1, 2]]]).astype(np.int32)) + updates = Tensor(np.arange(34, 70).reshape((2, 2, 3, 3)).astype(nptype)) + + net = TestScatterMinNet(inputx) + output = net(indices, updates) + expected = inputx.asnumpy() + np.testing.assert_array_almost_equal(output.asnumpy(), expected) + + +def scatter_min_dynamic_updates(): + inputx = Tensor(np.ones((4, 2, 3, 4)).astype(np.float32)) + indices = Tensor(np.array([[0, 2], [3, 1]]).astype(np.int32)) + updates = Tensor(np.arange(96).reshape((2, 2, 2, 3, 4)).astype(np.float32)) + updates_dy = Tensor(shape=(2, 2, 2, None, 4), dtype=mindspore.float32) + + net = TestScatterMinNet(inputx) + net.set_inputs(indices, updates_dy) + output = net(indices, updates) + expected = np.ones((4, 2, 3, 4)).astype(np.float32) + expected[0][0][0][0] = 0.0 + np.testing.assert_array_almost_equal(output.asnumpy(), expected) + + +def scatter_min_dynamic_indices(): + inputx = Tensor(np.ones((4, 2, 3, 4)).astype(np.int32)) + indices = Tensor(np.array([[0, 2], [3, 1]]).astype(np.int32)) + indices_dy = Tensor(shape=(2, None), dtype=mindspore.int32) + updates = Tensor(np.arange(96).reshape((2, 2, 2, 3, 4)).astype(np.int32)) + + net = TestScatterMinNet(inputx) + net.set_inputs(indices_dy, updates) + output = net(indices, updates) + expected = np.ones((4, 2, 3, 4)).astype(np.int32) + expected[0][0][0][0] = 0 + np.testing.assert_array_almost_equal(output.asnumpy(), expected) + + +class TestScatterMinGradNet(nn.Cell): + def __init__(self, network): + super(TestScatterMinGradNet, self).__init__() + self.grad = ops.GradOperation(get_all=True, sens_param=True, get_by_list=True) + self.network = network + self.params = ParameterTuple(network.trainable_params()) + + def construct(self, indices, updates, dout): + out = self.grad(self.network, self.params)(indices, updates, dout) + return out + + +def scatter_min_grad(nptype): + inputx = Tensor(np.flip(np.arange(34, 46).reshape(3, 4).astype(nptype))) + indices = Tensor(np.array([[[0, 1, 2], [2, 1, 0]], [[0, 0, 0], [2, 2, 2]]]).astype(np.int32)) + updates = Tensor(np.arange(63, 111).reshape((2, 2, 3, 4)).astype(nptype)) + dout = Tensor(np.flip(np.arange(0, 12).reshape((3, 4)).astype(nptype))) + + net = TestScatterMinGradNet(TestScatterMinNet(inputx)) + output = net(indices, updates, dout) + indices_grad = output[0][0] + updates_grad = output[0][1] + + indices_expected = np.array([[[0, 0, 0], [0, 0, 0]], [[0, 0, 0], [0, 0, 0]]]).astype(nptype) + updates_expected = np.array( + [ + [ + [ + [11, 10, 9, 8], [7, 6, 5, 4], [3, 2, 1, 0] + ], + [ + [3, 2, 1, 0], [7, 6, 5, 4], [11, 10, 9, 8] + ] + ], + [ + [ + [11, 10, 9, 8], [11, 10, 9, 8], [11, 10, 9, 8] + ], + [ + [3, 2, 1, 0], [3, 2, 1, 0], [3, 2, 1, 0] + ] + ] + ]).astype(nptype) + np.testing.assert_array_almost_equal(indices_grad, indices_expected) + np.testing.assert_array_almost_equal(updates_grad, updates_expected) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_scatter_min_forward_float16(): + """ + Feature: test scatter_min forward. + Description: test float16 inputs. + Expectation: the result match with numpy result + """ + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + scatter_min_forward(np.float16) + context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") + scatter_min_forward(np.float16) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_scatter_min_forward_float32(): + """ + Feature: test scatter_min forward. + Description: test float32 inputs. + Expectation: the result match with numpy result + """ + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + scatter_min_forward(np.float32) + context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") + scatter_min_forward(np.float32) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_scatter_min_forward_int32(): + """ + Feature: test scatter_min forward. + Description: test int32 inputs. + Expectation: the result match with numpy result + """ + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + scatter_min_forward(np.int32) + context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") + scatter_min_forward(np.int32) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_scatter_min_dynamic_indices(): + """ + Feature: test scatter_min dynamic shape. + Description: indices is dynamic shape. + Expectation: the result match with numpy result + """ + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + scatter_min_dynamic_indices() + context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") + scatter_min_dynamic_indices() + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_scatter_min_dynamic_updates(): + """ + Feature: test scatter_min dynamic shape. + Description: updates is dynamic shape. + Expectation: the result match with numpy result + """ + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + scatter_min_dynamic_updates() + context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") + scatter_min_dynamic_updates() + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_scatter_min_grad_float16(): + """ + Feature: test scatter_min grad. + Description: test float16 inputs. + Expectation: the result match with numpy result + """ + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + scatter_min_grad(np.float16) + context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") + scatter_min_grad(np.float16) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_scatter_min_grad_float32(): + """ + Feature: test scatter_min grad. + Description: test float32 inputs. + Expectation: the result match with numpy result + """ + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + scatter_min_grad(np.float32) + context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") + scatter_min_grad(np.float32) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_scatter_min_grad_int32(): + """ + Feature: test scatter_min grad. + Description: test int32 inputs. + Expectation: the result match with numpy result + """ + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + scatter_min_grad(np.int32) + context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") + scatter_min_grad(np.int32) diff --git a/tests/st/ops/gpu/test_scatter_func_op.py b/tests/st/ops/gpu/test_scatter_func_op.py index d05ee45e606..4ee0537b586 100644 --- a/tests/st/ops/gpu/test_scatter_func_op.py +++ b/tests/st/ops/gpu/test_scatter_func_op.py @@ -775,11 +775,6 @@ def test_scatter_func_disordered_dynamic_int8(): ).astype(np.int8) np.testing.assert_array_almost_equal(output.asnumpy(), expected) - # min - output = scatter_func_d_net("min", inputx, indices, updates) - expected = np.flip(np.arange(34, 46).reshape(3, 4).astype(np.int8)) - np.testing.assert_array_almost_equal(output.asnumpy(), expected) - @pytest.mark.level0 @pytest.mark.platform_x86_gpu_training @@ -817,11 +812,6 @@ def test_scatter_func_disordered_dynamic_uint8(): ).astype(np.uint8) np.testing.assert_array_almost_equal(output.asnumpy(), expected) - # min - output = scatter_func_d_net("min", inputx, indices, updates) - expected = np.flip(np.arange(34, 46).reshape(3, 4).astype(np.uint8)) - np.testing.assert_array_almost_equal(output.asnumpy(), expected) - @pytest.mark.level0 @pytest.mark.platform_x86_gpu_training