!34131 scatter_min: add dynamic_shape case for ascend

Merge pull request !34131 from hujiahui8/scatter_min
This commit is contained in:
i-robot 2022-05-13 06:35:46 +00:00 committed by Gitee
commit 0968b28871
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
15 changed files with 346 additions and 83 deletions

View File

@ -22,7 +22,7 @@ mindspore.ops.ScatterAdd
**输入:**
- **input_x** (Parameter) - ScatterAdd的输入任意维度的Parameter。
- **indices** (Tensor) - 指定相加操作的索引数据类型为mindspore.int32。
- **indices** (Tensor) - 指定相加操作的索引数据类型为mindspore.int32或者mindspore.int64
- **updates** (Tensor) - 指定与 `input_x` 相加操作的Tensor数据类型与 `input_x` 相同shape为 `indices.shape + x.shape[1:]`
**输出:**
@ -32,6 +32,6 @@ mindspore.ops.ScatterAdd
**异常:**
- **TypeError** - `use_locking` 不是bool。
- **TypeError** - `indices` 不是int32。
- **TypeError** - `indices` 不是int32或者int64
- **ValueError** - `updates` 的shape不等于 `indices.shape + x.shape[1:]`
- **RuntimeError** - 当 `input_x``updates` 类型不一致,需要进行类型转换时,如果 `updates` 不支持转成参数 `input_x` 需要的数据类型,就会报错。

View File

@ -20,7 +20,7 @@ mindspore.ops.ScatterMax
**输入:**
- **input_x** (Parameter)- ScatterMax的输入任意维度的Parameter。
- **indices** (Tensor) - 指定最大值操作的索引数据类型必须为mindspore.int32。
- **indices** (Tensor) - 指定最大值操作的索引数据类型必须为mindspore.int32或者mindspore.int64
- **updates** (Tensor) - 指定与 `input_x` 取最大值操作的Tensor数据类型与 `input_x` 相同shape为 `indices.shape + x.shape[1:]`
@ -31,6 +31,6 @@ mindspore.ops.ScatterMax
**异常:**
- **TypeError** - `use_locking` 不是bool。
- **TypeError** - `indices` 不是int32。
- **TypeError** - `indices` 不是int32或者int64
- **ValueError** - `updates` 的shape不等于 `indices.shape + x.shape[1:]`
- **RuntimeError** - 当 `input_x``updates` 类型不一致,需要进行类型转换时,如果 `updates` 不支持转成参数 `input_x` 需要的数据类型,就会报错。

View File

@ -11,7 +11,7 @@ mindspore.ops.ScatterMin
\text{input_x}[\text{indices}[i, ..., j], :]
= min(\text{input_x}[\text{indices}[i, ..., j], :], \text{updates}[i, ..., j, :])
输入的 `input_x``updates` 遵循隐式类型转换规则,以确保数据类型一致。如果数据类型不同,则低精度数据类型将转换为高精度的数据类型。当参数的数据类型需要转换则会抛出RuntimeError异常。
输入的 `input_x``updates` 遵循隐式类型转换规则,以确保数据类型一致。如果数据类型不同,则低精度数据类型将转换为高精度的数据类型。当`updates` 不支持转成 `input_x` 需要的数据类型则会抛出RuntimeError异常。
**参数:**
@ -20,7 +20,7 @@ mindspore.ops.ScatterMin
**输入:**
- **input_x** (Parameter) - ScatterMin的输入任意维度的Parameter。
- **indices** (Tensor) - 指定最小值操作的索引数据类型必须为mindspore.int32。
- **indices** (Tensor) - 指定最小值操作的索引数据类型必须为mindspore.int32或者mindspore.int64
- **updates** (Tensor) - 指定与 `input_x` 取最小值操作的Tensor数据类型与 `input_x` 相同shape为 `indices.shape + x.shape[1:]`
**输出:**
@ -30,6 +30,6 @@ mindspore.ops.ScatterMin
**异常:**
- **TypeError** - `use_locking` 不是bool。
- **TypeError** - `indices` 不是int32。
- **TypeError** - `indices` 不是int32或者int64
- **ValueError** - `updates` 的shape不等于 `indices.shape + x.shape[1:]`
- **RuntimeError** - 当 `input_x``updates` 类型不一致,需要进行类型转换时,如果 `updates` 不支持转成参数 `input_x` 需要的数据类型,就会报错。

View File

@ -19,7 +19,7 @@
**输入:**
- **input_x** (Parameter) - ScatterSub的输入任意维度的Parameter。
- **indices** (Tensor) - 指定相减操作的索引其数据类型必须为mindspore.int32。
- **indices** (Tensor) - 指定相减操作的索引其数据类型必须为mindspore.int32或者mindspore.int64
- **updates** (Tensor) - 指定与 `input_x` 相减的Tensor其数据类型与 `input_x` 的相同shape为 `indices_shape + x_shape[1:]`
**输出:**
@ -29,6 +29,6 @@
**异常:**
- **TypeError** - `use_locking` 不是bool。
- **TypeError** - `indices` 不是int32。
- **TypeError** - `indices` 不是int32或者int64
- **ValueError** - `updates` 的shape不是 `indices_shape + x_shape[1:]`
- **RuntimeError** - 当 `input_x``updates` 类型不一致,需要进行类型转换时,如果 `updates` 不支持转成参数 `input_x` 需要的数据类型,就会报错。

View File

@ -19,7 +19,7 @@
**输入:**
- **input_x** (Parameter) - ScatterUpdate的输入任意维度的Parameter。
- **indices** (Tensor) - 指定更新操作的索引。数据类型为int32。如果索引中存在重复项则更新的顺序无法得知。
- **indices** (Tensor) - 指定更新操作的索引。数据类型为int32或者int64。如果索引中存在重复项,则更新的顺序无法得知。
- **updates** (Tensor) - 指定与 `input_x` 更新操作的Tensor其数据类型与 `input_x` 相同shape为 `indices.shape + input_x.shape[1:]`
**输出:**
@ -29,5 +29,5 @@
**异常:**
- **TypeError** - `use_locking` 不是bool。
- **TypeError** - `indices` 不是int32。
- **TypeError** - `indices` 不是int32或者int64
- **RuntimeError** - 当 `input_x``updates` 类型不一致,需要进行类型转换时,如果 `updates` 不支持转成参数 `input_x` 需要的数据类型,就会报错。

View File

@ -16,7 +16,7 @@ mindspore.ops.scatter_max
**参数:**
- **input_x** (Parameter)- ScatterMax的输入任意维度的Parameter。
- **indices** (Tensor) - 指定最大值操作的索引数据类型必须为mindspore.int32。
- **indices** (Tensor) - 指定最大值操作的索引数据类型必须为mindspore.int32或者mindspore.int64
- **updates** (Tensor) - 指定与 `input_x` 取最大值操作的Tensor数据类型与 `input_x` 相同shape为 `indices.shape + x.shape[1:]`
**输出:**
@ -26,6 +26,6 @@ mindspore.ops.scatter_max
**异常:**
- **TypeError** - `use_locking` 不是bool。
- **TypeError** - `indices` 不是int32。
- **TypeError** - `indices` 不是int32或者int64
- **ValueError** - `updates` 的shape不等于 `indices.shape + x.shape[1:]`
- **RuntimeError** - 当 `input_x``updates` 类型不一致,需要进行类型转换时,如果 `updates` 不支持转成参数 `input_x` 需要的数据类型,就会报错。

View File

@ -16,7 +16,7 @@ mindspore.ops.scatter_min
**输入:**
- **input_x** (Parameter) - scatter_min的输入任意维度的Parameter。
- **indices** (Tensor) - 指定最小值操作的索引数据类型必须为mindspore.int32。
- **indices** (Tensor) - 指定最小值操作的索引数据类型必须为mindspore.int32或者mindspore.int64
- **updates** (Tensor) - 指定与 `input_x` 取最小值操作的Tensor数据类型与 `input_x` 相同shape为 `indices.shape + input_x.shape[1:]`
**输出:**
@ -25,6 +25,6 @@ mindspore.ops.scatter_min
**异常:**
- **TypeError** - `indices` 不是int32。
- **TypeError** - `indices` 不是int32或者int64
- **ValueError** - `updates` 的shape不等于 `indices.shape + input_x.shape[1:]`
- **RuntimeError** - 当 `input_x``updates` 类型不一致,需要进行类型转换时,如果 `updates` 不支持转成参数 `input_x` 需要的数据类型,就会报错。

View File

@ -307,6 +307,20 @@ MS_REG_GPU_KERNEL_TWO(ScatterMax,
ScatterFunctorKernelMod, uint8_t, int64_t)
// ScatterMin
MS_REG_GPU_KERNEL_TWO(ScatterMin,
KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat64),
ScatterFunctorKernelMod, double, int)
MS_REG_GPU_KERNEL_TWO(ScatterMin,
KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat64),
ScatterFunctorKernelMod, double, int64_t)
MS_REG_GPU_KERNEL_TWO(ScatterMin,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
@ -335,6 +349,20 @@ MS_REG_GPU_KERNEL_TWO(ScatterMin,
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16),
ScatterFunctorKernelMod, half, int64_t)
MS_REG_GPU_KERNEL_TWO(ScatterMin,
KernelAttr()
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeInt64),
ScatterFunctorKernelMod, int64_t, int)
MS_REG_GPU_KERNEL_TWO(ScatterMin,
KernelAttr()
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeInt64),
ScatterFunctorKernelMod, int64_t, int64_t)
MS_REG_GPU_KERNEL_TWO(ScatterMin,
KernelAttr()
.AddInputAttr(kNumberTypeInt32)
@ -349,33 +377,5 @@ MS_REG_GPU_KERNEL_TWO(ScatterMin,
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt32),
ScatterFunctorKernelMod, int, int64_t)
MS_REG_GPU_KERNEL_TWO(ScatterMin,
KernelAttr()
.AddInputAttr(kNumberTypeInt8)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt8)
.AddOutputAttr(kNumberTypeInt8),
ScatterFunctorKernelMod, int8_t, int)
MS_REG_GPU_KERNEL_TWO(ScatterMin,
KernelAttr()
.AddInputAttr(kNumberTypeInt8)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt8)
.AddOutputAttr(kNumberTypeInt8),
ScatterFunctorKernelMod, int8_t, int64_t)
MS_REG_GPU_KERNEL_TWO(ScatterMin,
KernelAttr()
.AddInputAttr(kNumberTypeUInt8)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeUInt8)
.AddOutputAttr(kNumberTypeUInt8),
ScatterFunctorKernelMod, uint8_t, int)
MS_REG_GPU_KERNEL_TWO(ScatterMin,
KernelAttr()
.AddInputAttr(kNumberTypeUInt8)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeUInt8)
.AddOutputAttr(kNumberTypeUInt8),
ScatterFunctorKernelMod, uint8_t, int64_t)
} // namespace kernel
} // namespace mindspore

View File

@ -128,6 +128,14 @@ template CUDA_LIB_EXPORT void ScatterFunc<half, int64_t>(enum ScatterFunctorType
const size_t &inner_size, const size_t &indices_size,
const int64_t *indices, const half *updates, half *input,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void ScatterFunc<double, int>(enum ScatterFunctorType func_type, int size_limit,
const size_t &inner_size, const size_t &indices_size,
const int *indices, const double *updates, double *input,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void ScatterFunc<double, int64_t>(enum ScatterFunctorType func_type, int64_t size_limit,
const size_t &inner_size, const size_t &indices_size,
const int64_t *indices, const double *updates, double *input,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void ScatterFunc<int, int>(enum ScatterFunctorType func_type, int size_limit,
const size_t &inner_size, const size_t &indices_size,
const int *indices, const int *updates, int *input,
@ -136,6 +144,14 @@ template CUDA_LIB_EXPORT void ScatterFunc<int, int64_t>(enum ScatterFunctorType
const size_t &inner_size, const size_t &indices_size,
const int64_t *indices, const int *updates, int *input,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void ScatterFunc<int64_t, int>(enum ScatterFunctorType func_type, int size_limit,
const size_t &inner_size, const size_t &indices_size,
const int *indices, const int64_t *updates, int64_t *input,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void ScatterFunc<int64_t, int64_t>(enum ScatterFunctorType func_type, int64_t size_limit,
const size_t &inner_size, const size_t &indices_size,
const int64_t *indices, const int64_t *updates,
int64_t *input, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void ScatterFunc<unsigned char, int>(enum ScatterFunctorType func_type, int size_limit,
const size_t &inner_size, const size_t &indices_size,
const int *indices, const unsigned char *updates,

View File

@ -40,7 +40,7 @@ abstract::ShapePtr ScatterMinInferShape(const PrimitivePtr &primitive, const std
}
if (indices_shape_ptr->IsDynamic() || updates_shape_ptr->IsDynamic()) {
return input_args[kInputIndex0]->BuildShape()->cast<abstract::ShapePtr>();
return input_x_shape_ptr->cast<abstract::ShapePtr>();
}
std::vector<int64_t> input_x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_x_shape_ptr)[kShape];
@ -66,12 +66,11 @@ TypePtr ScatterMinInferType(const PrimitivePtr &primitive, const std::vector<Abs
auto indiecs_type_ptr = input_args[kInputIndex1]->BuildType();
auto updates_type_ptr = input_args[kInputIndex2]->BuildType();
auto prim_name = primitive->name();
std::set<TypePtr> type_set = {kInt32};
(void)CheckAndConvertUtils::CheckTensorTypeValid("indices type", indiecs_type_ptr, type_set, prim_name);
(void)CheckAndConvertUtils::CheckTensorTypeValid("input_x type", input_x_type_ptr, common_valid_types_with_complex,
prim_name);
(void)CheckAndConvertUtils::CheckTensorTypeValid("updates type", updates_type_ptr, common_valid_types_with_complex,
prim_name);
const std::set<TypePtr> indices_types = {kInt32, kInt64};
const std::set<TypePtr> valid_types = {kInt32, kInt64, kFloat16, kFloat32, kFloat64};
(void)CheckAndConvertUtils::CheckTensorTypeValid("indices type", indiecs_type_ptr, indices_types, prim_name);
(void)CheckAndConvertUtils::CheckTensorTypeValid("input_x type", input_x_type_ptr, valid_types, prim_name);
(void)CheckAndConvertUtils::CheckTensorTypeValid("updates type", updates_type_ptr, valid_types, prim_name);
std::map<std::string, TypePtr> type_dict;
type_dict.emplace("input_x", input_x_type_ptr);

View File

@ -23,14 +23,19 @@ scatter_min_op_info = TBERegOp("ScatterMin") \
.compute_cost(10) \
.kernel_name("scatter_min") \
.partial_flag(True) \
.dynamic_compile_static(True) \
.dynamic_shape(True) \
.attr("use_locking", "optional", "bool", "all") \
.input(0, "var", False, "required", "all") \
.input(1, "indices", False, "required", "all") \
.input(2, "updates", False, "required", "all") \
.output(0, "var", False, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F16_Default, DataType.I64_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.I32_Default, DataType.I32_Default) \
.get_op_info()

View File

@ -769,7 +769,7 @@ def scatter_min(input_x, indices, updates):
Using given values to update tensor value through the min operation, along with the input indices.
This operation outputs the `input_x` after the update is done, which makes it convenient to use the updated value.
for each `i, ..., j` in `indices.shape`:
for each :math:`i, ..., j` in `indices.shape`:
.. math::
@ -781,21 +781,20 @@ def scatter_min(input_x, indices, updates):
the relatively highest priority data type.
Args:
- **input_x** (Parameter) - The target tensor, with data type of Parameter.
input_x (Parameter): The target tensor, with data type of Parameter.
The shape is :math:`(N,*)` where :math:`*` means,any number of additional dimensions.
- **indices** (Tensor) - The index to do min operation whose data type must be mindspore.int32.
- **updates** (Tensor) - The tensor doing the min operation with `input_x`,
the data type is same as `input_x`, the shape is `indices.shape + x.shape[1:]`.
indices (Tensor): The index to do min operation whose data type must be mindspore.int32 or mindspore.int64.
updates (Tensor): The tensor doing the min operation with `input_x`,
the data type is same as `input_x`, the shape is `indices.shape + input_x.shape[1:]`.
Outputs:
Tensor, the updated `input_x`, has the same shape and type as `input_x`.
Raises:
TypeError: If `indices` is not an int32.
TypeError: If `use_locking` is not a bool.
TypeError: If `indices` is not an int32 or an int64.
RuntimeError: If the data type of `input_x` and `updates` conversion of Parameter
is required when data type conversion of Parameter is not supported.
ValueError: If the shape of `updates` is not equal to `indices.shape + x.shape[1:]`.
ValueError: If the shape of `updates` is not equal to `indices.shape + input_x.shape[1:]`.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``

View File

@ -109,7 +109,7 @@ class _ScatterOpDynamic(PrimitiveWithCheck):
self._check_scatter_shape(x_shape, indices_shape, updates_shape, self.name)
def check_dtype(self, x_dtype, indices_dtype, updates_dtype):
validator.check_tensor_dtype_valid('indices', indices_dtype, [mstype.int32], self.name)
validator.check_tensor_dtype_valid('indices', indices_dtype, [mstype.int32, mstype.int64], self.name)
args = {"x": x_dtype, "updates": updates_dtype}
validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
@ -4160,7 +4160,7 @@ class ScatterUpdate(_ScatterOpDynamic):
Inputs:
- **input_x** (Parameter) - The target tensor, with data type of Parameter.
The shape is :math:`(N,*)` where :math:`*` means,any number of additional dimensions.
- **indices** (Tensor) - The index of input tensor. With int32 data type.
- **indices** (Tensor) - The index of input tensor. With int32 or int64 data type.
If there are duplicates in indices, the order for updating is undefined.
- **updates** (Tensor) - The tensor to update the input tensor, has the same type as input,
and updates.shape = indices.shape + input_x.shape[1:].
@ -4170,7 +4170,7 @@ class ScatterUpdate(_ScatterOpDynamic):
Raises:
TypeError: If `use_locking` is not a bool.
TypeError: If `indices` is not an int32.
TypeError: If `indices` is not an int32 or an int64.
RuntimeError: If the data type of `input_x` and `updates` conversion of Parameter
is required when data type conversion of Parameter is not supported.
@ -4289,7 +4289,8 @@ class ScatterMax(_ScatterOpDynamic):
Inputs:
- **input_x** (Parameter) - The target tensor, with data type of Parameter.
The shape is :math:`(N,*)` where :math:`*` means,any number of additional dimensions.
- **indices** (Tensor) - The index to do max operation whose data type must be mindspore.int32.
- **indices** (Tensor) - The index to do max operation whose data type must be mindspore.int32 or
mindspore.int64.
- **updates** (Tensor) - The tensor that performs the maximum operation with `input_x`,
the data type is the same as `input_x`, the shape is `indices.shape + x.shape[1:]`.
@ -4298,7 +4299,7 @@ class ScatterMax(_ScatterOpDynamic):
Raises:
TypeError: If `use_locking` is not a bool.
TypeError: If `indices` is not an int32.
TypeError: If `indices` is not an int32 or an int64.
ValueError: If the shape of `updates` is not equal to `indices.shape + x.shape[1:]`.
RuntimeError: If the data type of `input_x` and `updates` conversion of Parameter
is required when data type conversion of Parameter is not supported.
@ -4326,7 +4327,7 @@ class ScatterMin(_ScatterOpDynamic):
Using given values to update tensor value through the min operation, along with the input indices.
This operation outputs the `input_x` after the update is done, which makes it convenient to use the updated value.
for each `i, ..., j` in `indices.shape`:
for each :math:`i, ..., j` in `indices.shape`:
.. math::
@ -4343,17 +4344,18 @@ class ScatterMin(_ScatterOpDynamic):
Inputs:
- **input_x** (Parameter) - The target tensor, with data type of Parameter.
The shape is :math:`(N,*)` where :math:`*` means,any number of additional dimensions.
- **indices** (Tensor) - The index to do min operation whose data type must be mindspore.int32.
- **indices** (Tensor) - The index to do min operation whose data type must be mindspore.int32 or
mindspore.int64.
- **updates** (Tensor) - The tensor doing the min operation with `input_x`,
the data type is same as `input_x`, the shape is `indices.shape + x.shape[1:]`.
the data type is same as `input_x`, the shape is `indices.shape + input_x.shape[1:]`.
Outputs:
Tensor, the updated `input_x`, has the same shape and type as `input_x`.
Raises:
TypeError: If `use_locking` is not a bool.
TypeError: If `indices` is not an int32.
ValueError: If the shape of `updates` is not equal to `indices.shape + x.shape[1:]`.
TypeError: If `indices` is not an int32 or an int64.
ValueError: If the shape of `updates` is not equal to `indices.shape + input_x.shape[1:]`.
RuntimeError: If the data type of `input_x` and `updates` conversion of Parameter
is required when data type conversion of Parameter is not supported.
@ -4401,7 +4403,8 @@ class ScatterAdd(_ScatterOpDynamic):
Inputs:
- **input_x** (Parameter) - The target tensor, with data type of Parameter.
The shape is :math:`(N,*)` where :math:`*` means,any number of additional dimensions.
- **indices** (Tensor) - The index to do min operation whose data type must be mindspore.int32.
- **indices** (Tensor) - The index to do min operation whose data type must be mindspore.int32 or
mindspore.int64.
- **updates** (Tensor) - The tensor doing the min operation with `input_x`,
the data type is same as `input_x`, the shape is `indices.shape + x.shape[1:]`.
@ -4410,7 +4413,7 @@ class ScatterAdd(_ScatterOpDynamic):
Raises:
TypeError: If `use_locking` is not a bool.
TypeError: If `indices` is not an int32.
TypeError: If `indices` is not an int32 or an int64.
ValueError: If the shape of `updates` is not equal to `indices.shape + x.shape[1:]`.
RuntimeError: If the data type of `input_x` and `updates` conversion of Parameter
is required when data type conversion of Parameter is not supported.
@ -4511,7 +4514,8 @@ class ScatterSub(_ScatterOpDynamic):
Inputs:
- **input_x** (Parameter) - The target tensor, with data type of Parameter.
The shape is :math:`(N,*)` where :math:`*` means,any number of additional dimensions.
- **indices** (Tensor) - The index to do min operation whose data type must be mindspore.int32.
- **indices** (Tensor) - The index to do min operation whose data type must be mindspore.int32 or
mindspore.int64.
- **updates** (Tensor) - The tensor doing the min operation with `input_x`,
the data type is same as `input_x`, the shape is `indices_shape + x_shape[1:]`.
@ -4520,7 +4524,7 @@ class ScatterSub(_ScatterOpDynamic):
Raises:
TypeError: If `use_locking` is not a bool.
TypeError: If `indices` is not an int32.
TypeError: If `indices` is not an int32 or an int64.
ValueError: If the shape of `updates` is not equal to `indices_shape + x_shape[1:]`.
RuntimeError: If the data type of `input_x` and `updates` conversion of Parameter
is required when data type conversion of Parameter is not supported.

View File

@ -0,0 +1,250 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore
import mindspore.context as context
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Tensor, Parameter, ParameterTuple
# all cases tested against dchip
class TestScatterMinNet(nn.Cell):
def __init__(self, inputx):
super(TestScatterMinNet, self).__init__()
self.scatter_min = ops.ScatterMin()
self.inputx = Parameter(inputx, name="inputx")
def construct(self, indices, updates):
out = self.scatter_min(self.inputx, indices, updates)
return out
def scatter_min_forward(nptype):
inputx = Tensor(np.arange(0, 9).reshape((3, 3)).astype(nptype))
indices = Tensor(np.array([[[1, 0, 2], [2, 2, 0]], [[1, 0, 1], [2, 1, 2]]]).astype(np.int32))
updates = Tensor(np.arange(34, 70).reshape((2, 2, 3, 3)).astype(nptype))
net = TestScatterMinNet(inputx)
output = net(indices, updates)
expected = inputx.asnumpy()
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
def scatter_min_dynamic_updates():
inputx = Tensor(np.ones((4, 2, 3, 4)).astype(np.float32))
indices = Tensor(np.array([[0, 2], [3, 1]]).astype(np.int32))
updates = Tensor(np.arange(96).reshape((2, 2, 2, 3, 4)).astype(np.float32))
updates_dy = Tensor(shape=(2, 2, 2, None, 4), dtype=mindspore.float32)
net = TestScatterMinNet(inputx)
net.set_inputs(indices, updates_dy)
output = net(indices, updates)
expected = np.ones((4, 2, 3, 4)).astype(np.float32)
expected[0][0][0][0] = 0.0
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
def scatter_min_dynamic_indices():
inputx = Tensor(np.ones((4, 2, 3, 4)).astype(np.int32))
indices = Tensor(np.array([[0, 2], [3, 1]]).astype(np.int32))
indices_dy = Tensor(shape=(2, None), dtype=mindspore.int32)
updates = Tensor(np.arange(96).reshape((2, 2, 2, 3, 4)).astype(np.int32))
net = TestScatterMinNet(inputx)
net.set_inputs(indices_dy, updates)
output = net(indices, updates)
expected = np.ones((4, 2, 3, 4)).astype(np.int32)
expected[0][0][0][0] = 0
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
class TestScatterMinGradNet(nn.Cell):
def __init__(self, network):
super(TestScatterMinGradNet, self).__init__()
self.grad = ops.GradOperation(get_all=True, sens_param=True, get_by_list=True)
self.network = network
self.params = ParameterTuple(network.trainable_params())
def construct(self, indices, updates, dout):
out = self.grad(self.network, self.params)(indices, updates, dout)
return out
def scatter_min_grad(nptype):
inputx = Tensor(np.flip(np.arange(34, 46).reshape(3, 4).astype(nptype)))
indices = Tensor(np.array([[[0, 1, 2], [2, 1, 0]], [[0, 0, 0], [2, 2, 2]]]).astype(np.int32))
updates = Tensor(np.arange(63, 111).reshape((2, 2, 3, 4)).astype(nptype))
dout = Tensor(np.flip(np.arange(0, 12).reshape((3, 4)).astype(nptype)))
net = TestScatterMinGradNet(TestScatterMinNet(inputx))
output = net(indices, updates, dout)
indices_grad = output[0][0]
updates_grad = output[0][1]
indices_expected = np.array([[[0, 0, 0], [0, 0, 0]], [[0, 0, 0], [0, 0, 0]]]).astype(nptype)
updates_expected = np.array(
[
[
[
[11, 10, 9, 8], [7, 6, 5, 4], [3, 2, 1, 0]
],
[
[3, 2, 1, 0], [7, 6, 5, 4], [11, 10, 9, 8]
]
],
[
[
[11, 10, 9, 8], [11, 10, 9, 8], [11, 10, 9, 8]
],
[
[3, 2, 1, 0], [3, 2, 1, 0], [3, 2, 1, 0]
]
]
]).astype(nptype)
np.testing.assert_array_almost_equal(indices_grad, indices_expected)
np.testing.assert_array_almost_equal(updates_grad, updates_expected)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_scatter_min_forward_float16():
"""
Feature: test scatter_min forward.
Description: test float16 inputs.
Expectation: the result match with numpy result
"""
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
scatter_min_forward(np.float16)
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
scatter_min_forward(np.float16)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_scatter_min_forward_float32():
"""
Feature: test scatter_min forward.
Description: test float32 inputs.
Expectation: the result match with numpy result
"""
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
scatter_min_forward(np.float32)
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
scatter_min_forward(np.float32)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_scatter_min_forward_int32():
"""
Feature: test scatter_min forward.
Description: test int32 inputs.
Expectation: the result match with numpy result
"""
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
scatter_min_forward(np.int32)
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
scatter_min_forward(np.int32)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_scatter_min_dynamic_indices():
"""
Feature: test scatter_min dynamic shape.
Description: indices is dynamic shape.
Expectation: the result match with numpy result
"""
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
scatter_min_dynamic_indices()
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
scatter_min_dynamic_indices()
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_scatter_min_dynamic_updates():
"""
Feature: test scatter_min dynamic shape.
Description: updates is dynamic shape.
Expectation: the result match with numpy result
"""
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
scatter_min_dynamic_updates()
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
scatter_min_dynamic_updates()
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_scatter_min_grad_float16():
"""
Feature: test scatter_min grad.
Description: test float16 inputs.
Expectation: the result match with numpy result
"""
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
scatter_min_grad(np.float16)
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
scatter_min_grad(np.float16)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_scatter_min_grad_float32():
"""
Feature: test scatter_min grad.
Description: test float32 inputs.
Expectation: the result match with numpy result
"""
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
scatter_min_grad(np.float32)
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
scatter_min_grad(np.float32)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_scatter_min_grad_int32():
"""
Feature: test scatter_min grad.
Description: test int32 inputs.
Expectation: the result match with numpy result
"""
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
scatter_min_grad(np.int32)
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
scatter_min_grad(np.int32)

View File

@ -775,11 +775,6 @@ def test_scatter_func_disordered_dynamic_int8():
).astype(np.int8)
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
# min
output = scatter_func_d_net("min", inputx, indices, updates)
expected = np.flip(np.arange(34, 46).reshape(3, 4).astype(np.int8))
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@ -817,11 +812,6 @@ def test_scatter_func_disordered_dynamic_uint8():
).astype(np.uint8)
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
# min
output = scatter_func_d_net("min", inputx, indices, updates)
expected = np.flip(np.arange(34, 46).reshape(3, 4).astype(np.uint8))
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training